Skip to content

Commit b57af0c

Browse files
committed
metal : initial FA vec kernel
1 parent f8d709f commit b57af0c

File tree

2 files changed

+78
-121
lines changed

2 files changed

+78
-121
lines changed

ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2603,7 +2603,7 @@ static enum ggml_status ggml_metal_graph_compute(
26032603

26042604
// simdgroups per threadgroup (a.k.a. warps)
26052605
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
2606-
const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
2606+
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
26072607

26082608
int64_t nsg = 1;
26092609
while (nsg <= nsgt) {

ggml-metal.metal

Lines changed: 77 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,7 +2459,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f
24592459

24602460
#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
24612461

2462-
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
2462+
template<int64_t D, int64_t C> // head size, cache items per threadgroup
24632463
kernel void kernel_flash_attn_ext_vec_f16(
24642464
device const char * q,
24652465
device const char * k,
@@ -2499,12 +2499,12 @@ kernel void kernel_flash_attn_ext_vec_f16(
24992499

25002500
const short iq3 = tgpig[2];
25012501
const short iq2 = tgpig[1];
2502-
const short iq1 = tgpig[0]*Q;
2502+
const short iq1 = tgpig[0];
25032503

25042504
const short D4 = D/4;
25052505
const short D8 = D/8;
25062506
const short NW = N_SIMDWIDTH;
2507-
const short SH = (C + Q); // shared memory per simdgroup in (half)
2507+
const short SH = (C + 1); // shared memory per simdgroup in (half)
25082508

25092509
const short T = D + nsg*SH; // shared memory size per query in (half)
25102510
const short T4 = T/4; // shared memory size per query in (half4)
@@ -2513,43 +2513,37 @@ kernel void kernel_flash_attn_ext_vec_f16(
25132513
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
25142514
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
25152515
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
2516-
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
2516+
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
25172517

25182518
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2519-
half4 lo[Q][D4/NW];
2519+
half4 lo[D4/NW];
25202520

25212521
// load heads from Q to shared memory
2522-
for (short j = sgitg; j < Q; j += nsg) {
2523-
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
2522+
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
25242523

2525-
for (short i = tiisg; i < D4; i += NW) {
2526-
if (iq1 + j < ne01) {
2527-
sq4[j*T4 + i] = (half4) q4[i];
2528-
} else {
2529-
sq4[j*T4 + i] = 0.0h;
2530-
}
2524+
for (short i = tiisg; i < D4; i += NW) {
2525+
if (iq1 < ne01) {
2526+
sq4[i] = (half4) q4[i];
2527+
} else {
2528+
sq4[i] = 0.0h;
25312529
}
25322530
}
25332531

25342532
// zero out lo
2535-
for (short j = 0; j < Q; ++j) {
2536-
for (short i = tiisg; i < D4; i += NW) {
2537-
lo[j][i/NW] = 0.0h;
2538-
}
2533+
for (short i = tiisg; i < D4; i += NW) {
2534+
lo[i/NW] = 0.0h;
25392535
}
25402536

25412537
// zero out shared memory SH
2542-
for (short j = 0; j < Q; ++j) {
2543-
for (short i = tiisg; i < SH/4; i += NW) {
2544-
ss4[j*T4 + i] = 0.0h;
2545-
}
2538+
for (short i = tiisg; i < SH/4; i += NW) {
2539+
ss4[i] = 0.0h;
25462540
}
25472541

25482542
threadgroup_barrier(mem_flags::mem_threadgroup);
25492543

25502544
{
2551-
half S[Q] = { [0 ... Q-1] = 0.0h };
2552-
half M[Q] = { [0 ... Q-1] = -HALF_MAX_HALF };
2545+
half S = { 0.0h };
2546+
half M = { -HALF_MAX_HALF };
25532547

25542548
// assume K and V are same shape
25552549
const short ne22 = ne12;
@@ -2575,21 +2569,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
25752569
const short iv3 = iq3 / rv3;
25762570

25772571
// load the queries from shared memory into local memory
2578-
half4 mq[Q][D4];
2572+
half4 mq[D4];
25792573

2580-
for (short j = 0; j < Q; ++j) {
2581-
for (short ii = 0; ii < D4; ii += NW) {
2582-
short i = ii + tiisg;
2583-
mq[j][i] = sq4[j*T4 + i];
2584-
}
2574+
for (short ii = 0; ii < D4; ii += NW) {
2575+
short i = ii + tiisg;
2576+
mq[i] = sq4[i];
25852577
}
25862578

25872579
// pointer to the mask
25882580
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
25892581

2590-
// prepare diagonal scale matrix
2591-
half mscale(scale);
2592-
25932582
// loop over the KV cache
25942583
// each simdgroup handles blocks of Q rows and C columns
25952584
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -2600,8 +2589,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
26002589

26012590
// Q*K^T
26022591
{
2592+
#pragma unroll
26032593
for (short cc = 0; cc < C/4; ++cc) {
2604-
half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
2594+
half4 mqk = { 0.0h };
26052595

26062596
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
26072597

@@ -2615,142 +2605,110 @@ kernel void kernel_flash_attn_ext_vec_f16(
26152605
mk[2] = pk4[i + 2*(nb11/8)];
26162606
mk[3] = pk4[i + 3*(nb11/8)];
26172607

2618-
for (short j = 0; j < Q; ++j) {
2619-
mqk[j] += mq[j][i] * mk;
2620-
}
2608+
mqk += mq[i] * mk;
26212609
}
26222610

26232611
// reduce the results from the threads in the simdgroup
2624-
for (short j = 0; j < Q; ++j) {
2625-
mqk[j] += simd_shuffle_down(mqk[j], 16);
2626-
mqk[j] += simd_shuffle_down(mqk[j], 8);
2627-
mqk[j] += simd_shuffle_down(mqk[j], 4);
2628-
mqk[j] += simd_shuffle_down(mqk[j], 2);
2629-
mqk[j] += simd_shuffle_down(mqk[j], 1);
2630-
}
2612+
mqk += simd_shuffle_down(mqk, 16);
2613+
mqk += simd_shuffle_down(mqk, 8);
2614+
mqk += simd_shuffle_down(mqk, 4);
2615+
mqk += simd_shuffle_down(mqk, 2);
2616+
mqk += simd_shuffle_down(mqk, 1);
26312617

26322618
// mqk = mqk*scale + mask
26332619
if (tiisg == 0) {
2634-
for (short j = 0; j < Q; ++j) {
2635-
half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
2636-
mqk[j] = mqk[j]*mscale + mm;
2620+
half4 mm = mp4[ic/4 + cc];
2621+
mqk = mqk*scale + mm;
26372622

2638-
ss4[j*T4 + cc] = mqk[j];
2639-
}
2623+
ss4[cc] = mqk;
26402624
}
26412625
}
26422626
}
26432627

26442628
// online softmax
2645-
half ms[Q];
2646-
2647-
for (short j = 0; j < Q; ++j) {
2629+
{
26482630
const short p = tiisg;
26492631

2650-
const half m = M[j];
2651-
const half s = ss[j*T + p];
2632+
const half m = M;
2633+
const half s = ss[p];
26522634

2653-
M[j] = simd_max(max(M[j], s));
2635+
M = simd_max(max(M, s));
26542636

2655-
ms[j] = exp(m - M[j]);
2656-
const half vs = exp(s - M[j]);
2637+
const half ms = exp(m - M);
2638+
const half vs = exp(s - M);
26572639

2658-
S[j] = S[j]*ms[j] + simd_sum(vs);
2640+
S = S*ms + simd_sum(vs);
26592641

26602642
// the P matrix from the paper (Q rows, C columns)
2661-
ss[j*T + p] = vs;
2662-
}
2663-
2664-
// create a QxQ diagonal matrix for rescaling the output
2665-
if (tiisg < Q) {
2666-
ss[tiisg*T + C + tiisg] = ms[tiisg];
2667-
}
2668-
2669-
// O = diag(ms)*O
2670-
for (short j = 0; j < Q; ++j) {
2671-
half mm(ss[j*T + C + j]);
2643+
ss[p] = vs;
26722644

2645+
// O = diag(ms)*O
26732646
#pragma unroll
26742647
for (short ii = 0; ii < D4; ii += NW) {
26752648
const short i = ii + tiisg;
2676-
lo[j][i/NW] = lo[j][i/NW]*mm;
2649+
lo[i/NW] *= ms;
26772650
}
26782651
}
26792652

26802653
// O = O + (Q*K^T)*V
26812654
{
26822655
#pragma unroll
2683-
for (short cc = 0; cc < C; ++cc) {
2684-
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
2656+
for (short cc = 0; cc < C/4; ++cc) {
2657+
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
26852658

26862659
#pragma unroll
26872660
for (short ii = 0; ii < D4; ii += NW) {
2688-
short i = ii + tiisg;
2689-
for (short j = 0; j < Q; ++j) {
2690-
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
2691-
}
2661+
const short i = ii + tiisg;
2662+
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
2663+
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
2664+
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
2665+
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
26922666
}
26932667
}
26942668
}
2669+
26952670
}
26962671

26972672
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
2698-
for (short j = 0; j < Q; ++j) {
2699-
if (tiisg == 0) {
2700-
ss[j*T + 0] = S[j];
2701-
ss[j*T + 1] = M[j];
2702-
}
2673+
if (tiisg == 0) {
2674+
ss[0] = S;
2675+
ss[1] = M;
27032676
}
27042677
}
27052678

27062679
// store results to shared memory
2707-
for (short j = 0; j < Q; ++j) {
2708-
for (short ii = 0; ii < D4; ii += NW) {
2709-
short i = ii + tiisg;
2710-
sr4[i] = lo[j][ii/NW];
2711-
}
2680+
for (short ii = 0; ii < D4; ii += NW) {
2681+
short i = ii + tiisg;
2682+
sr4[i] = lo[ii/NW];
27122683
}
27132684

27142685
threadgroup_barrier(mem_flags::mem_threadgroup);
27152686

27162687
// parallel reduce
27172688
for (short r = nsg/2; r > 0; r >>= 1) {
27182689
if (sgitg < r) {
2719-
if (tiisg == 0) {
2720-
for (short j = 0; j < Q; ++j) {
2721-
const half S0 = ss[j*T + 0];
2722-
const half S1 = ss[j*T + r*SH + 0];
2690+
const half S0 = ss[ 0];
2691+
const half S1 = ss[r*SH + 0];
27232692

2724-
const half M0 = ss[j*T + 1];
2725-
const half M1 = ss[j*T + r*SH + 1];
2693+
const half M0 = ss[ 1];
2694+
const half M1 = ss[r*SH + 1];
27262695

2727-
const half M = max(M0, M1);
2696+
const half M = max(M0, M1);
27282697

2729-
const half ms0 = exp(M0 - M);
2730-
const half ms1 = exp(M1 - M);
2698+
const half ms0 = exp(M0 - M);
2699+
const half ms1 = exp(M1 - M);
27312700

2732-
const half S = S0*ms0 + S1*ms1;
2733-
2734-
ss[j*T + 0] = S;
2735-
ss[j*T + 1] = M;
2701+
const half S = S0*ms0 + S1*ms1;
27362702

2737-
ss[j*T + C + j ] = ms0;
2738-
ss[j*T + C + j + r*SH] = ms1;
2739-
}
2703+
if (tiisg == 0) {
2704+
ss[0] = S;
2705+
ss[1] = M;
27402706
}
2741-
}
2742-
2743-
threadgroup_barrier(mem_flags::mem_threadgroup);
27442707

2745-
if (sgitg < r) {
2746-
for (short j = 0; j < Q; ++j) {
2747-
const half ms0 = ss[j*T + C + j];
2748-
const half ms1 = ss[j*T + C + j + r*SH];
2749-
2750-
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2751-
for (short i = tiisg; i < D4; i += NW) {
2752-
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
2753-
}
2708+
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2709+
for (short ii = 0; ii < D4; ii += NW) {
2710+
short i = ii + tiisg;
2711+
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
27542712
}
27552713
}
27562714

@@ -2761,18 +2719,17 @@ kernel void kernel_flash_attn_ext_vec_f16(
27612719

27622720
// final rescale with 1/S and store to global memory
27632721
if (sgitg == 0) {
2764-
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
2765-
const half S = ss[j*T + 0];
2722+
const half S = ss[0];
27662723

2767-
for (short i = tiisg; i < D4; i += NW) {
2768-
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S;
2769-
}
2724+
for (short ii = 0; ii < D4; ii += NW) {
2725+
short i = ii + tiisg;
2726+
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
27702727
}
27712728
}
27722729
}
27732730

2774-
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
2775-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
2731+
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>;
2732+
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>;
27762733

27772734
kernel void kernel_cpy_f16_f16(
27782735
device const half * src0,

0 commit comments

Comments
 (0)