Commit e51778d

metal : switch to parallel reduce
1 parent 5733b00 commit e51778d

2 files changed: 119 additions & 93 deletions

ggml-metal.m

Lines changed: 13 additions & 3 deletions
@@ -2615,13 +2615,23 @@ static enum ggml_status ggml_metal_graph_compute(

         // simdgroups per threadgroup (a.k.a. warps)
         // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-        const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+        //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+        const int64_t nsg = 8;

-        const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+        // require power of 2
+        //{
+        //    int64_t nsgm = 1;
+        //    while (nsgm < nsg) {
+        //        nsgm *= 2;
+        //    }
+        //    GGML_ASSERT(nsg == nsgm);
+        //}
+
+        const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

         //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
         GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
-        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+        [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
    }
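
The dispatch now fixes the simdgroup count at nsg = 8, reserves an extra nsg*ne00 halfs of threadgroup memory for the per-simdgroup result scratch used by the kernel below, and pads the request to a 16-byte multiple. A minimal host-side sketch of the sizing arithmetic, assuming GGML_PAD rounds its first argument up to the next multiple of the second (the pad_to helper and the example sizes are illustrative, not taken from ggml):

    #include <cstdint>
    #include <cstdio>

    // Round x up to the next multiple of n (n a power of two) -- assumed behavior of GGML_PAD.
    static size_t pad_to(size_t x, size_t n) { return (x + n - 1) & ~(n - 1); }

    int main() {
        // Illustrative values: head size ne00 = 128, nsg = 8 simdgroups,
        // nqptg = 1 query and ncpsg = 32 cache items per threadgroup.
        const int64_t ne00 = 128, nsg = 8, nqptg = 1, ncpsg = 32;

        // Same formula as above: query/attention scratch plus nsg*ne00 halfs for per-simdgroup results.
        const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

        std::printf("smem = %zu bytes, padded = %zu bytes\n", smem, pad_to(smem, 16));
        return 0;
    }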

ggml-metal.metal

Lines changed: 106 additions & 90 deletions
@@ -2457,6 +2457,8 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;

+#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+
 template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
         device const char * q,
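
This constant exists because the online-softmax rescale further down computes exp(m - M) with m being the previous running maximum: if both are initialized to -INFINITY, the subtraction is (-inf) - (-inf) = NaN under IEEE-754, and the NaN then propagates into the accumulators. A tiny illustration of the failure mode (plain C++, float instead of half):

    #include <cmath>
    #include <cstdio>

    int main() {
        // With an -INFINITY initial maximum the first rescale factor becomes exp(-inf - (-inf)) = exp(NaN) = NaN.
        const float m_inf = -INFINITY;
        std::printf("exp(m - M), m = M = -inf   : %f\n", std::exp(m_inf - m_inf)); // nan

        // With a large-but-finite negative initial value the subtraction stays finite.
        const float m_fin = -65504.0f/2; // mirrors -HALF_MAX_HALF
        std::printf("exp(m - M), m = M = finite : %f\n", std::exp(m_fin - m_fin)); // 1.0
        return 0;
    }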
@@ -2500,6 +2502,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
     const short iq1 = tgpig[0]*Q;

     const short D4 = D/4;
+    const short D8 = D/8;
     const short NW = N_SIMDWIDTH;
     const short SH = (C + Q); // shared memory per simdgroup in (half)

@@ -2510,6 +2513,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
     threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
     threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
     threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D  + Q*T); // scratch buffer for the results

     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
     half4 lo[Q][D4];
@@ -2545,7 +2549,7 @@ kernel void kernel_flash_attn_ext_vec_f16(

     {
         half S[Q] = { [0 ... Q-1] = 0.0h };
-        half M[Q] = { [0 ... Q-1] = -INFINITY };
+        half M[Q] = { [0 ... Q-1] = -HALF_MAX_HALF };

         // assume K and V are same shape
         const short ne22 = ne12;
@@ -2571,21 +2575,21 @@ kernel void kernel_flash_attn_ext_vec_f16(
         const short iv3 = iq3 / rv3;

         // load the queries from shared memory into local memory
-        half4 mq[Q][D4];
+        simdgroup_half8x8 mq[Q][D8];

         for (short j = 0; j < Q; ++j) {
-            for (short i = tiisg; i < D4; i += NW) {
-                //simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
-                mq[j][i] = sq4[j*T4 + i];
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
             }
         }

         // pointer to the mask
-        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+        //device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+        device const half * mp = (device const half *) (mask + iq1*nb31);

         // prepare diagonal scale matrix
-        //simdgroup_half8x8 mscale(scale);
-        half mscale(scale);
+        simdgroup_half8x8 mscale(scale);
+        //half mscale(scale);

         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
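
With the queries now loaded as 8x8 simdgroup tiles (and the mask read through a plain half pointer), the Q*K^T hunk below computes scale*(Q·K^T) + mask entirely with simdgroup matrix ops instead of per-thread half4 dot products and shuffle reductions. As a plain scalar reference for what one such block amounts to, here is a hypothetical C++ stand-in (not the Metal intrinsics; shapes chosen freely):

    #include <vector>

    // One block of attention scores: s[j][c] = scale * dot(q[j], k[c]) + mask[j][c],
    // for Q query rows, C cache columns and head size D (row-major buffers).
    static void qk_block(const std::vector<float>& q, const std::vector<float>& k,
                         const std::vector<float>& mask, std::vector<float>& s,
                         int Q, int C, int D, float scale) {
        for (int j = 0; j < Q; ++j) {
            for (int c = 0; c < C; ++c) {
                float acc = 0.0f;
                for (int i = 0; i < D; ++i) {
                    acc += q[j*D + i] * k[c*D + i]; // K rows are read directly, i.e. effectively transposed
                }
                s[j*C + c] = scale*acc + mask[j*C + c];
            }
        }
    }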
@@ -2595,55 +2599,83 @@ kernel void kernel_flash_attn_ext_vec_f16(
                 break;
             }

+            // Q*K^T
+            //{
+            //    for (short cc = 0; cc < C/4; ++cc) {
+            //        half4 mqk[Q];
+            //        for (short j = 0; j < Q; ++j) {
+            //            mqk[j] = 0.0h;
+            //        }
+
+            //        device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+            //        for (short i = tiisg; i < D4; i += NW) {
+            //            half4x4 mk;
+            //            mk[0] = pk4[i + 0*(nb11/8)];
+            //            mk[1] = pk4[i + 1*(nb11/8)];
+            //            mk[2] = pk4[i + 2*(nb11/8)];
+            //            mk[3] = pk4[i + 3*(nb11/8)];
+
+            //            for (short j = 0; j < Q; ++j) {
+            //                mqk[j] += mq[j][i] * mk;
+            //            }
+            //        }
+
+            //        // reduce the results from the threads in the simdgroup
+            //        simdgroup_barrier(mem_flags::mem_none);
+
+            //        for (short i = NW/2; i > 0; i /= 2) {
+            //            if (tiisg < i) {
+            //                for (short j = 0; j < Q; ++j) {
+            //                    mqk[j] += simd_shuffle_down(mqk[j], i);
+            //                }
+            //            }
+
+            //            simdgroup_barrier(mem_flags::mem_none);
+            //        }
+
+            //        // mqk = mqk*scale + mask
+            //        if (tiisg == 0) {
+            //            for (short j = 0; j < Q; ++j) {
+            //                half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
+            //                mqk[j] = mqk[j]*mscale + mm;
+
+            //                ss4[j*T4 + cc] = mqk[j];
+            //            }
+            //        }
+            //    }
+            //}
+
             // Q*K^T
             {
-                for (short cc = 0; cc < C/4; ++cc) {
-                    half4 mqk[Q];
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_half8x8 mqk[Q];
                     for (short j = 0; j < Q; ++j) {
-                        mqk[j] = 0.0h;
+                        mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
                     }

-                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));

-                    for (short i = tiisg; i < D4; i += NW) {
-                        half4x4 mk;
-                        mk[0] = pk4[i + 0*(nb11/8)];
-                        mk[1] = pk4[i + 1*(nb11/8)];
-                        mk[2] = pk4[i + 2*(nb11/8)];
-                        mk[3] = pk4[i + 3*(nb11/8)];
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose

                         for (short j = 0; j < Q; ++j) {
-                            mqk[j] += mq[j][i] * mk;
-                        }
-                    }
-
-                    // reduce the results from the threads in the simdgroup
-                    simdgroup_barrier(mem_flags::mem_none);
-
-                    for (short i = NW/2; i > 0; i /= 2) {
-                        if (tiisg < i) {
-                            for (short j = 0; j < Q; ++j) {
-                                mqk[j] += simd_shuffle_down(mqk[j], i);
-                            }
+                            simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
                         }
-
-                        simdgroup_barrier(mem_flags::mem_none);
                     }

                     // mqk = mqk*scale + mask
-                    if (tiisg == 0) {
-                        for (short j = 0; j < Q; ++j) {
-                            half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
-                            mqk[j] = mqk[j]*mscale + mm;
+                    for (short j = 0; j < Q; ++j) {
+                        simdgroup_half8x8 mm;
+                        simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
+                        simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);

-                            ss4[j*T4 + cc] = mqk[j];
-                        }
+                        simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
                     }
                 }
             }

-            simdgroup_barrier(mem_flags::mem_threadgroup);
-
             // online softmax
             half ms[Q];

@@ -2655,8 +2687,8 @@ kernel void kernel_flash_attn_ext_vec_f16(

                 M[j] = simd_max(max(M[j], s));

-                ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
-                const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+                ms[j] = exp(m - M[j]);
+                const half vs = exp(s - M[j]);

                 S[j] = S[j]*ms[j] + simd_sum(vs);

@@ -2706,75 +2738,59 @@ kernel void kernel_flash_attn_ext_vec_f16(
         }
     }

-    // reduce the warps sequentially
-    for (short sg = 1; sg < nsg; ++sg) {
-        half S = { 0.0h };
-        half M = { -INFINITY };
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);

-        // each simdgroup stores its output to shared memory, reusing sq
-        if (sgitg == sg) {
-            for (short j = 0; j < Q; ++j) {
-                for (short i = tiisg; i < D4; i += NW) {
-                    //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
-                    sq4[j*T4 + i] = lo[j][i];
-                }
-            }
+    // store results to shared memory
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < D4; i += NW) {
+            sr4[i] = lo[j][i];
         }
+    }

-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // the first simdgroup accumulates the results from the other simdgroups
-        if (sgitg == 0) {
-            for (short j = 0; j < Q; ++j) {
-                const half S0 = ss[j*T +         0];
-                const half S1 = ss[j*T + sg*SH + 0];
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            if (tiisg == 0) {
+                for (short j = 0; j < Q; ++j) {
+                    const half S0 = ss[j*T +        0];
+                    const half S1 = ss[j*T + r*SH + 0];

-                const half M0 = ss[j*T +         1];
-                const half M1 = ss[j*T + sg*SH + 1];
+                    const half M0 = ss[j*T +        1];
+                    const half M1 = ss[j*T + r*SH + 1];

-                M = max(M0, M1);
+                    const half M = max(M0, M1);

-                const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
-                const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
+                    const half ms0 = exp(M0 - M);
+                    const half ms1 = exp(M1 - M);

-                S = S0*ms0 + S1*ms1;
+                    const half S = S0*ms0 + S1*ms1;

-                if (tiisg == 0) {
                     ss[j*T + 0] = S;
                     ss[j*T + 1] = M;

-                    ss[j*T + C + j        ] = ms0;
-                    ss[j*T + C + j + sg*SH] = ms1;
+                    ss[j*T + C + j       ] = ms0;
+                    ss[j*T + C + j + r*SH] = ms1;
                 }
             }
+        }

-            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (sgitg < r) {
             for (short j = 0; j < Q; ++j) {
-                for (short i = tiisg; i < D4; i += NW) {
-                    half4 t = sq4[j*T4 + i];
-                    half ms0 = ss[j*T + C + j];
-                    half ms1 = ss[j*T + C + j + sg*SH];
+                const half ms0 = ss[j*T + C + j];
+                const half ms1 = ss[j*T + C + j + r*SH];

-                    lo[j][i] = lo[j][i]*ms0 + t*ms1;
+                // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+                for (short i = tiisg; i < D4; i += NW) {
+                    sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
                 }
             }
         }
-    }

-    // store result to shared memory (reuse sq)
-    if (sgitg == 0) {
-        for (short j = 0; j < Q; ++j) {
-            for (short i = tiisg; i < D4; i += NW) {
-                //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
-                sq4[j*T4 + i] = lo[j][i];
-            }
-        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }

-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
     device float4 * dst4 = (device float4 *) dst;

     // final rescale with 1/S and store to global memory
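
This is the change the commit title refers to: instead of simdgroup 0 folding in the other warps one after another, the per-simdgroup partial states (S, M, O) are merged pairwise in log2(nsg) rounds, with simdgroup sg < r absorbing simdgroup sg + r in round r. A minimal host-side sketch of that merge, assuming nsg is a power of two (which the commented-out assert in ggml-metal.m is meant to guarantee):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Partial online-softmax state produced by one simdgroup: running sum S, running max M,
    // and the unnormalized output accumulator O (one float per head dimension).
    struct Partial {
        float S;
        float M;
        std::vector<float> O;
    };

    // Merge partials[0..nsg) into partials[0] with a pairwise tree reduction; nsg must be a power of two.
    static void parallel_reduce(std::vector<Partial>& partials) {
        const int nsg = (int) partials.size();
        for (int r = nsg/2; r > 0; r >>= 1) {
            for (int sg = 0; sg < r; ++sg) {            // on the GPU these r merges run concurrently
                Partial& a = partials[sg];
                const Partial& b = partials[sg + r];

                const float M   = std::max(a.M, b.M);
                const float ms0 = std::exp(a.M - M);
                const float ms1 = std::exp(b.M - M);

                a.S = a.S*ms0 + b.S*ms1;
                a.M = M;
                for (size_t i = 0; i < a.O.size(); ++i) {
                    a.O[i] = a.O[i]*ms0 + b.O[i]*ms1;   // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
                }
            }
        }
        // partials[0] now holds the combined state; the final hunk divides O by S.
    }

In the kernel, the two barriers per round separate the (S, M) merge from the O merge, so every thread sees the updated scale factors ms0/ms1 in shared memory before rescaling sr4.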
@@ -2783,7 +2799,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
         const half S = ss[j*T + 0];

         for (short i = tiisg; i < D4; i += NW) {
-            dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S;
         }
     }
 }
