@@ -2121,7 +2121,7 @@ typedef void (flash_attn_ext_f16_t)(
2121
2121
ushort sgitg[[simdgroup_index_in_threadgroup]]);
2122
2122
2123
2123
// ref: https://arxiv.org/pdf/2307.08691.pdf
2124
- template <int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
2124
+ template <int64_t D, int64_t Q = 8 , int64_t C = 32 > // head size, queries per threadgroup, cache items per threadgroup
2125
2125
kernel void kernel_flash_attn_ext_f16 (
2126
2126
device const char * q,
2127
2127
device const char * k,
@@ -2178,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
2178
2178
threadgroup float * ss = (threadgroup float *) (shared + 2 *sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
2179
2179
2180
2180
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2181
- simdgroup_half8x8 lo[Q8][ D8];
2181
+ simdgroup_half8x8 lo[D8];
2182
2182
2183
2183
// load heads from Q to shared memory
2184
2184
for (short j = sgitg; j < Q; j += nsg) {
@@ -2194,10 +2194,8 @@ kernel void kernel_flash_attn_ext_f16(
2194
2194
}
2195
2195
2196
2196
// zero out lo
2197
- for (short j = 0 ; j < Q8; ++j) {
2198
- for (short i = 0 ; i < D8; ++i) {
2199
- lo[j][i] = make_filled_simdgroup_matrix<half, 8 >(0 .0h);
2200
- }
2197
+ for (short i = 0 ; i < D8; ++i) {
2198
+ lo[i] = make_filled_simdgroup_matrix<half, 8 >(0 .0h);
2201
2199
}
2202
2200
2203
2201
// zero out shared memory SH
@@ -2229,20 +2227,18 @@ kernel void kernel_flash_attn_ext_f16(
2229
2227
const short rv3 = ne03/ne23;
2230
2228
2231
2229
// k indices
2232
- const short ik2 = iq2 / rk2;
2233
- const short ik3 = iq3 / rk3;
2230
+ const short ik2 = iq2/ rk2;
2231
+ const short ik3 = iq3/ rk3;
2234
2232
2235
2233
// v indices
2236
- const short iv2 = iq2 / rv2;
2237
- const short iv3 = iq3 / rv3;
2234
+ const short iv2 = iq2/ rv2;
2235
+ const short iv3 = iq3/ rv3;
2238
2236
2239
2237
// load the queries from shared memory into local memory
2240
- simdgroup_half8x8 mq[Q8][ D8];
2238
+ simdgroup_half8x8 mq[D8];
2241
2239
2242
- for (short j = 0 ; j < Q8; ++j) {
2243
- for (short i = 0 ; i < D8; ++i) {
2244
- simdgroup_load (mq[j][i], sq + 8 *j*T + i*8 , T);
2245
- }
2240
+ for (short i = 0 ; i < D8; ++i) {
2241
+ simdgroup_load (mq[i], sq + i*8 , T);
2246
2242
}
2247
2243
2248
2244
// pointer to the mask
@@ -2262,38 +2258,31 @@ kernel void kernel_flash_attn_ext_f16(
2262
2258
// Q*K^T
2263
2259
{
2264
2260
for (short cc = 0 ; cc < C/8 ; ++cc) {
2265
- simdgroup_float8x8 mqk[Q8];
2266
- for (short j = 0 ; j < Q8; ++j) {
2267
- mqk[j] = make_filled_simdgroup_matrix<float , 8 >(0 .h );
2268
- }
2261
+ simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float , 8 >(0 .h );
2269
2262
2270
2263
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8 *cc)*nb11 + ik2*nb12 + ik3*nb13));
2271
2264
2272
2265
for (short i = 0 ; i < D8; ++i) {
2273
2266
simdgroup_half8x8 mk;
2274
2267
simdgroup_load (mk, pk + i*8 , nb11/sizeof (half), 0 , true ); // transpose
2275
2268
2276
- for (short j = 0 ; j < Q8; ++j) {
2277
- simdgroup_multiply_accumulate (mqk[j], mq[j][i], mk, mqk[j]);
2278
- }
2269
+ simdgroup_multiply_accumulate (mqk, mq[i], mk, mqk);
2279
2270
}
2280
2271
2281
2272
// mqk = mqk*scale + mask
2282
- for (short j = 0 ; j < Q8; ++j) {
2283
- simdgroup_half8x8 mm;
2284
- simdgroup_load (mm, mp + 8 *j*(nb31/sizeof (half)) + ic + 8 *cc, nb31/sizeof (half), 0 , false );
2285
- simdgroup_multiply_accumulate (mqk[j], mqk[j], mscale, mm);
2273
+ simdgroup_half8x8 mm;
2274
+ simdgroup_load (mm, mp + ic + 8 *cc, nb31/sizeof (half), 0 , false );
2275
+ simdgroup_multiply_accumulate (mqk, mqk, mscale, mm);
2286
2276
2287
- simdgroup_store (mqk[j], ss + 8 *j*TF + 8 *cc, TF, 0 , false );
2288
- }
2277
+ simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
2289
2278
}
2290
2279
}
2291
2280
2292
2281
// used to detect blocks full of -INF
2293
2282
float smax = -INFINITY;
2294
2283
2295
2284
// online softmax
2296
- if (C == 32 ) {
2285
+ {
2297
2286
float ms[Q];
2298
2287
2299
2288
for (short j = 0 ; j < Q; ++j) {
@@ -2314,45 +2303,6 @@ kernel void kernel_flash_attn_ext_f16(
2314
2303
ss[j*TF + p] = vs;
2315
2304
}
2316
2305
2317
- // create a QxQ diagonal matrix for rescaling the output
2318
- if (tiisg < Q) {
2319
- ss[tiisg*TF + C + tiisg] = ms[tiisg];
2320
- }
2321
- } else {
2322
- float ms[Q];
2323
-
2324
- for (short j = 0 ; j < Q; ++j) {
2325
- const float m = M[j];
2326
-
2327
- for (short p = tiisg; p < C; p += NW) {
2328
- const float s = ss[j*TF + p];
2329
-
2330
- smax = max (smax, s);
2331
- M[j] = max (M[j], s);
2332
- }
2333
-
2334
- smax = simd_max (smax);
2335
- M[j] = simd_max (M[j]);
2336
-
2337
- ms[j] = exp (m - M[j]);
2338
-
2339
- // local sum
2340
- float ls = 0 .0h;
2341
-
2342
- for (short p = tiisg; p < C; p += NW) {
2343
- const float s = ss[j*TF + p];
2344
-
2345
- const float vs = exp (s - M[j]);
2346
-
2347
- ls += vs;
2348
-
2349
- // the P matrix from the paper (Q rows, C columns)
2350
- ss[j*TF + p] = vs;
2351
- }
2352
-
2353
- S[j] = S[j]*ms[j] + simd_sum (ls);
2354
- }
2355
-
2356
2306
// create a QxQ diagonal matrix for rescaling the output
2357
2307
if (tiisg < Q) {
2358
2308
ss[tiisg*TF + C + tiisg] = ms[tiisg];
@@ -2365,12 +2315,12 @@ kernel void kernel_flash_attn_ext_f16(
2365
2315
}
2366
2316
2367
2317
// O = diag(ms)*O
2368
- for ( short j = 0 ; j < Q8; ++j) {
2318
+ {
2369
2319
simdgroup_float8x8 mm;
2370
- simdgroup_load (mm, ss + 8 *j*TF + C + 8 *j , TF, 0 , false );
2320
+ simdgroup_load (mm, ss + C , TF, 0 , false );
2371
2321
2372
2322
for (short i = 0 ; i < D8; ++i) {
2373
- simdgroup_multiply (lo[j][ i], mm, lo[j] [i]);
2323
+ simdgroup_multiply (lo[i], mm, lo[i]);
2374
2324
}
2375
2325
}
2376
2326
@@ -2383,12 +2333,10 @@ kernel void kernel_flash_attn_ext_f16(
2383
2333
simdgroup_half8x8 mk;
2384
2334
simdgroup_load (mk, pv + i*8 , nb21/sizeof (half), 0 , false );
2385
2335
2386
- for (short j = 0 ; j < Q8; ++j) {
2387
- simdgroup_float8x8 mv;
2388
- simdgroup_load (mv, ss + 8 *j*TF + 8 *cc, TF, 0 , false );
2336
+ simdgroup_float8x8 mv;
2337
+ simdgroup_load (mv, ss + 8 *cc, TF, 0 , false );
2389
2338
2390
- simdgroup_multiply_accumulate (lo[j][i], mv, mk, lo[j][i]);
2391
- }
2339
+ simdgroup_multiply_accumulate (lo[i], mv, mk, lo[i]);
2392
2340
}
2393
2341
}
2394
2342
}
@@ -2412,10 +2360,8 @@ kernel void kernel_flash_attn_ext_f16(
2412
2360
2413
2361
// each simdgroup stores its output to shared memory, reusing sq
2414
2362
if (sgitg == sg) {
2415
- for (short j = 0 ; j < Q8; ++j) {
2416
- for (short i = 0 ; i < D8; ++i) {
2417
- simdgroup_store (lo[j][i], sq + 8 *j*T + i*8 , T, 0 , false );
2418
- }
2363
+ for (short i = 0 ; i < D8; ++i) {
2364
+ simdgroup_store (lo[i], sq + i*8 , T, 0 , false );
2419
2365
}
2420
2366
}
2421
2367
@@ -2447,30 +2393,28 @@ kernel void kernel_flash_attn_ext_f16(
2447
2393
}
2448
2394
2449
2395
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2450
- for ( short j = 0 ; j < Q8; ++j) {
2396
+ {
2451
2397
simdgroup_half8x8 t;
2452
2398
simdgroup_float8x8 ms0;
2453
2399
simdgroup_float8x8 ms1;
2454
2400
2455
- simdgroup_load (ms0, ss + 8 *j*TF + C + 8 *j , TF, 0 , false );
2456
- simdgroup_load (ms1, ss + 8 *j*TF + C + 8 *j + sg*SH, TF, 0 , false );
2401
+ simdgroup_load (ms0, ss + C , TF, 0 , false );
2402
+ simdgroup_load (ms1, ss + C + sg*SH, TF, 0 , false );
2457
2403
2458
2404
for (short i = 0 ; i < D8; ++i) {
2459
- simdgroup_load (t, sq + 8 *j*T + i*8 , T, 0 , false );
2405
+ simdgroup_load (t, sq + i*8 , T, 0 , false );
2460
2406
simdgroup_multiply (t, ms1, t);
2461
2407
2462
- simdgroup_multiply_accumulate (lo[j][ i], ms0, lo[j] [i], t);
2408
+ simdgroup_multiply_accumulate (lo[i], ms0, lo[i], t);
2463
2409
}
2464
2410
}
2465
2411
}
2466
2412
}
2467
2413
2468
2414
// store result to shared memory (reuse sq)
2469
2415
if (sgitg == 0 ) {
2470
- for (short j = 0 ; j < Q8; ++j) {
2471
- for (short i = 0 ; i < D8; ++i) {
2472
- simdgroup_store (lo[j][i], sq + 8 *j*T + i*8 , T, 0 , false );
2473
- }
2416
+ for (short i = 0 ; i < D8; ++i) {
2417
+ simdgroup_store (lo[i], sq + i*8 , T, 0 , false );
2474
2418
}
2475
2419
}
2476
2420
@@ -2488,14 +2432,14 @@ kernel void kernel_flash_attn_ext_f16(
2488
2432
}
2489
2433
}
2490
2434
2491
- template [[host_name(" kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64 , 8 , 32 >;
2492
- template [[host_name(" kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80 , 8 , 32 >;
2493
- template [[host_name(" kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96 , 8 , 32 >;
2494
- template [[host_name(" kernel_flash_attn_ext_f16_h112" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112 , 8 , 32 >;
2495
- template [[host_name(" kernel_flash_attn_ext_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128 , 8 , 32 >;
2496
- template [[host_name(" kernel_flash_attn_ext_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256 , 8 , 32 >;
2435
+ template [[host_name(" kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64 >;
2436
+ template [[host_name(" kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80 >;
2437
+ template [[host_name(" kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96 >;
2438
+ template [[host_name(" kernel_flash_attn_ext_f16_h112" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112 >;
2439
+ template [[host_name(" kernel_flash_attn_ext_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128 >;
2440
+ template [[host_name(" kernel_flash_attn_ext_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256 >;
2497
2441
2498
- template <int64_t D, int64_t C > // head size, queries per threadgroup, cache items per threadgroup
2442
+ template <int64_t D, int64_t Q = 1 , int64_t C = 32 > // head size, queries per threadgroup, cache items per threadgroup
2499
2443
kernel void kernel_flash_attn_ext_vec_f16 (
2500
2444
device const char * q,
2501
2445
device const char * k,
@@ -2539,7 +2483,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2539
2483
2540
2484
const short D4 = D/4 ;
2541
2485
const short NW = N_SIMDWIDTH;
2542
- const short SH = (C + 1 ); // shared memory per simdgroup in (half)
2486
+ const short SH = (C + Q ); // shared memory per simdgroup in (half)
2543
2487
2544
2488
const short T = D + 2 *nsg*SH; // shared memory size per query in (half)
2545
2489
@@ -2763,8 +2707,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
2763
2707
}
2764
2708
}
2765
2709
2766
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 32 >;
2767
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256 , 32 >;
2710
+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 >;
2711
+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256 >;
2768
2712
2769
2713
kernel void kernel_cpy_f16_f16 (
2770
2714
device const half * src0,
0 commit comments