Skip to content

Commit 1a88565

Browse files
committed
metal : clean-up kernel code
1 parent 97eaece commit 1a88565

File tree

1 file changed

+43
-99
lines changed

1 file changed

+43
-99
lines changed

ggml-metal.metal

Lines changed: 43 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,7 +2121,7 @@ typedef void (flash_attn_ext_f16_t)(
21212121
ushort sgitg[[simdgroup_index_in_threadgroup]]);
21222122

21232123
// ref: https://arxiv.org/pdf/2307.08691.pdf
2124-
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
2124+
template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
21252125
kernel void kernel_flash_attn_ext_f16(
21262126
device const char * q,
21272127
device const char * k,
@@ -2178,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
21782178
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
21792179

21802180
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2181-
simdgroup_half8x8 lo[Q8][D8];
2181+
simdgroup_half8x8 lo[D8];
21822182

21832183
// load heads from Q to shared memory
21842184
for (short j = sgitg; j < Q; j += nsg) {
@@ -2194,10 +2194,8 @@ kernel void kernel_flash_attn_ext_f16(
21942194
}
21952195

21962196
// zero out lo
2197-
for (short j = 0; j < Q8; ++j) {
2198-
for (short i = 0; i < D8; ++i) {
2199-
lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
2200-
}
2197+
for (short i = 0; i < D8; ++i) {
2198+
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
22012199
}
22022200

22032201
// zero out shared memory SH
@@ -2229,20 +2227,18 @@ kernel void kernel_flash_attn_ext_f16(
22292227
const short rv3 = ne03/ne23;
22302228

22312229
// k indices
2232-
const short ik2 = iq2 / rk2;
2233-
const short ik3 = iq3 / rk3;
2230+
const short ik2 = iq2/rk2;
2231+
const short ik3 = iq3/rk3;
22342232

22352233
// v indices
2236-
const short iv2 = iq2 / rv2;
2237-
const short iv3 = iq3 / rv3;
2234+
const short iv2 = iq2/rv2;
2235+
const short iv3 = iq3/rv3;
22382236

22392237
// load the queries from shared memory into local memory
2240-
simdgroup_half8x8 mq[Q8][D8];
2238+
simdgroup_half8x8 mq[D8];
22412239

2242-
for (short j = 0; j < Q8; ++j) {
2243-
for (short i = 0; i < D8; ++i) {
2244-
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
2245-
}
2240+
for (short i = 0; i < D8; ++i) {
2241+
simdgroup_load(mq[i], sq + i*8, T);
22462242
}
22472243

22482244
// pointer to the mask
@@ -2262,38 +2258,31 @@ kernel void kernel_flash_attn_ext_f16(
22622258
// Q*K^T
22632259
{
22642260
for (short cc = 0; cc < C/8; ++cc) {
2265-
simdgroup_float8x8 mqk[Q8];
2266-
for (short j = 0; j < Q8; ++j) {
2267-
mqk[j] = make_filled_simdgroup_matrix<float, 8>(0.h);
2268-
}
2261+
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
22692262

22702263
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
22712264

22722265
for (short i = 0; i < D8; ++i) {
22732266
simdgroup_half8x8 mk;
22742267
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
22752268

2276-
for (short j = 0; j < Q8; ++j) {
2277-
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
2278-
}
2269+
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
22792270
}
22802271

22812272
// mqk = mqk*scale + mask
2282-
for (short j = 0; j < Q8; ++j) {
2283-
simdgroup_half8x8 mm;
2284-
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
2285-
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
2273+
simdgroup_half8x8 mm;
2274+
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2275+
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
22862276

2287-
simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false);
2288-
}
2277+
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
22892278
}
22902279
}
22912280

22922281
// used to detect blocks full of -INF
22932282
float smax = -INFINITY;
22942283

22952284
// online softmax
2296-
if (C == 32) {
2285+
{
22972286
float ms[Q];
22982287

22992288
for (short j = 0; j < Q; ++j) {
@@ -2314,45 +2303,6 @@ kernel void kernel_flash_attn_ext_f16(
23142303
ss[j*TF + p] = vs;
23152304
}
23162305

2317-
// create a QxQ diagonal matrix for rescaling the output
2318-
if (tiisg < Q) {
2319-
ss[tiisg*TF + C + tiisg] = ms[tiisg];
2320-
}
2321-
} else {
2322-
float ms[Q];
2323-
2324-
for (short j = 0; j < Q; ++j) {
2325-
const float m = M[j];
2326-
2327-
for (short p = tiisg; p < C; p += NW) {
2328-
const float s = ss[j*TF + p];
2329-
2330-
smax = max(smax, s);
2331-
M[j] = max(M[j], s);
2332-
}
2333-
2334-
smax = simd_max(smax);
2335-
M[j] = simd_max(M[j]);
2336-
2337-
ms[j] = exp(m - M[j]);
2338-
2339-
// local sum
2340-
float ls = 0.0h;
2341-
2342-
for (short p = tiisg; p < C; p += NW) {
2343-
const float s = ss[j*TF + p];
2344-
2345-
const float vs = exp(s - M[j]);
2346-
2347-
ls += vs;
2348-
2349-
// the P matrix from the paper (Q rows, C columns)
2350-
ss[j*TF + p] = vs;
2351-
}
2352-
2353-
S[j] = S[j]*ms[j] + simd_sum(ls);
2354-
}
2355-
23562306
// create a QxQ diagonal matrix for rescaling the output
23572307
if (tiisg < Q) {
23582308
ss[tiisg*TF + C + tiisg] = ms[tiisg];
@@ -2365,12 +2315,12 @@ kernel void kernel_flash_attn_ext_f16(
23652315
}
23662316

23672317
// O = diag(ms)*O
2368-
for (short j = 0; j < Q8; ++j) {
2318+
{
23692319
simdgroup_float8x8 mm;
2370-
simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false);
2320+
simdgroup_load(mm, ss + C, TF, 0, false);
23712321

23722322
for (short i = 0; i < D8; ++i) {
2373-
simdgroup_multiply(lo[j][i], mm, lo[j][i]);
2323+
simdgroup_multiply(lo[i], mm, lo[i]);
23742324
}
23752325
}
23762326

@@ -2383,12 +2333,10 @@ kernel void kernel_flash_attn_ext_f16(
23832333
simdgroup_half8x8 mk;
23842334
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
23852335

2386-
for (short j = 0; j < Q8; ++j) {
2387-
simdgroup_float8x8 mv;
2388-
simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false);
2336+
simdgroup_float8x8 mv;
2337+
simdgroup_load(mv, ss + 8*cc, TF, 0, false);
23892338

2390-
simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
2391-
}
2339+
simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
23922340
}
23932341
}
23942342
}
@@ -2412,10 +2360,8 @@ kernel void kernel_flash_attn_ext_f16(
24122360

24132361
// each simdgroup stores its output to shared memory, reusing sq
24142362
if (sgitg == sg) {
2415-
for (short j = 0; j < Q8; ++j) {
2416-
for (short i = 0; i < D8; ++i) {
2417-
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
2418-
}
2363+
for (short i = 0; i < D8; ++i) {
2364+
simdgroup_store(lo[i], sq + i*8, T, 0, false);
24192365
}
24202366
}
24212367

@@ -2447,30 +2393,28 @@ kernel void kernel_flash_attn_ext_f16(
24472393
}
24482394

24492395
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2450-
for (short j = 0; j < Q8; ++j) {
2396+
{
24512397
simdgroup_half8x8 t;
24522398
simdgroup_float8x8 ms0;
24532399
simdgroup_float8x8 ms1;
24542400

2455-
simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false);
2456-
simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false);
2401+
simdgroup_load(ms0, ss + C, TF, 0, false);
2402+
simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
24572403

24582404
for (short i = 0; i < D8; ++i) {
2459-
simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
2405+
simdgroup_load (t, sq + i*8, T, 0, false);
24602406
simdgroup_multiply(t, ms1, t);
24612407

2462-
simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
2408+
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
24632409
}
24642410
}
24652411
}
24662412
}
24672413

24682414
// store result to shared memory (reuse sq)
24692415
if (sgitg == 0) {
2470-
for (short j = 0; j < Q8; ++j) {
2471-
for (short i = 0; i < D8; ++i) {
2472-
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
2473-
}
2416+
for (short i = 0; i < D8; ++i) {
2417+
simdgroup_store(lo[i], sq + i*8, T, 0, false);
24742418
}
24752419
}
24762420

@@ -2488,14 +2432,14 @@ kernel void kernel_flash_attn_ext_f16(
24882432
}
24892433
}
24902434

2491-
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
2492-
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
2493-
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>;
2494-
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>;
2495-
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
2496-
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
2435+
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
2436+
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
2437+
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2438+
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2439+
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2440+
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
24972441

2498-
template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
2442+
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
24992443
kernel void kernel_flash_attn_ext_vec_f16(
25002444
device const char * q,
25012445
device const char * k,
@@ -2539,7 +2483,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
25392483

25402484
const short D4 = D/4;
25412485
const short NW = N_SIMDWIDTH;
2542-
const short SH = (C + 1); // shared memory per simdgroup in (half)
2486+
const short SH = (C + Q); // shared memory per simdgroup in (half)
25432487

25442488
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
25452489

@@ -2763,8 +2707,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
27632707
}
27642708
}
27652709

2766-
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>;
2767-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>;
2710+
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2711+
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
27682712

27692713
kernel void kernel_cpy_f16_f16(
27702714
device const half * src0,

0 commit comments

Comments
 (0)