Skip to content

Commit f8d709f

Browse files
committed
metal : simplify
1 parent c4dff1e commit f8d709f

File tree

2 files changed

+54
-87
lines changed

2 files changed

+54
-87
lines changed

ggml-metal.m

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2573,7 +2573,7 @@ static enum ggml_status ggml_metal_graph_compute(
25732573
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
25742574

25752575
// half8x8 kernel
2576-
if (ne01 > 1) {
2576+
if (ne01 > 1 || (ne00%128 != 0)) {
25772577
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
25782578
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
25792579

@@ -2603,8 +2603,13 @@ static enum ggml_status ggml_metal_graph_compute(
26032603

26042604
// simdgroups per threadgroup (a.k.a. warps)
26052605
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
2606-
//const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
2607-
const int64_t nsg = 8;
2606+
const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
2607+
2608+
int64_t nsg = 1;
2609+
while (nsg <= nsgt) {
2610+
nsg *= 2;
2611+
}
2612+
nsg /= 2;
26082613

26092614
// require power of 2
26102615
//{

ggml-metal.metal

Lines changed: 46 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,21 +2575,20 @@ kernel void kernel_flash_attn_ext_vec_f16(
25752575
const short iv3 = iq3 / rv3;
25762576

25772577
// load the queries from shared memory into local memory
2578-
simdgroup_half8x8 mq[Q][D8];
2578+
half4 mq[Q][D4];
25792579

25802580
for (short j = 0; j < Q; ++j) {
2581-
for (short i = 0; i < D8; ++i) {
2582-
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
2581+
for (short ii = 0; ii < D4; ii += NW) {
2582+
short i = ii + tiisg;
2583+
mq[j][i] = sq4[j*T4 + i];
25832584
}
25842585
}
25852586

25862587
// pointer to the mask
2587-
//device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2588-
device const half * mp = (device const half *) (mask + iq1*nb31);
2588+
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
25892589

25902590
// prepare diagonal scale matrix
2591-
simdgroup_half8x8 mscale(scale);
2592-
//half mscale(scale);
2591+
half mscale(scale);
25932592

25942593
// loop over the KV cache
25952594
// each simdgroup handles blocks of Q rows and C columns
@@ -2599,79 +2598,45 @@ kernel void kernel_flash_attn_ext_vec_f16(
25992598
break;
26002599
}
26012600

2602-
// Q*K^T
2603-
//{
2604-
// for (short cc = 0; cc < C/4; ++cc) {
2605-
// half4 mqk[Q];
2606-
// for (short j = 0; j < Q; ++j) {
2607-
// mqk[j] = 0.0h;
2608-
// }
2609-
2610-
// device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
2611-
2612-
// for (short i = tiisg; i < D4; i += NW) {
2613-
// half4x4 mk;
2614-
// mk[0] = pk4[i + 0*(nb11/8)];
2615-
// mk[1] = pk4[i + 1*(nb11/8)];
2616-
// mk[2] = pk4[i + 2*(nb11/8)];
2617-
// mk[3] = pk4[i + 3*(nb11/8)];
2618-
2619-
// for (short j = 0; j < Q; ++j) {
2620-
// mqk[j] += mq[j][i] * mk;
2621-
// }
2622-
// }
2623-
2624-
// // reduce the results from the threads in the simdgroup
2625-
// simdgroup_barrier(mem_flags::mem_none);
2626-
2627-
// for (short i = NW/2; i > 0; i /= 2) {
2628-
// if (tiisg < i) {
2629-
// for (short j = 0; j < Q; ++j) {
2630-
// mqk[j] += simd_shuffle_down(mqk[j], i);
2631-
// }
2632-
// }
2633-
2634-
// simdgroup_barrier(mem_flags::mem_none);
2635-
// }
2636-
2637-
// // mqk = mqk*scale + mask
2638-
// if (tiisg == 0) {
2639-
// for (short j = 0; j < Q; ++j) {
2640-
// half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
2641-
// mqk[j] = mqk[j]*mscale + mm;
2642-
2643-
// ss4[j*T4 + cc] = mqk[j];
2644-
// }
2645-
// }
2646-
// }
2647-
//}
2648-
26492601
// Q*K^T
26502602
{
2651-
for (short cc = 0; cc < C/8; ++cc) {
2652-
simdgroup_half8x8 mqk[Q];
2653-
for (short j = 0; j < Q; ++j) {
2654-
mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
2655-
}
2603+
for (short cc = 0; cc < C/4; ++cc) {
2604+
half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
26562605

2657-
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
2606+
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
26582607

2659-
for (short i = 0; i < D8; ++i) {
2660-
simdgroup_half8x8 mk;
2661-
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
2608+
#pragma unroll
2609+
for (short ii = 0; ii < D4; ii += NW) {
2610+
const short i = ii + tiisg;
2611+
2612+
half4x4 mk;
2613+
mk[0] = pk4[i + 0*(nb11/8)];
2614+
mk[1] = pk4[i + 1*(nb11/8)];
2615+
mk[2] = pk4[i + 2*(nb11/8)];
2616+
mk[3] = pk4[i + 3*(nb11/8)];
26622617

26632618
for (short j = 0; j < Q; ++j) {
2664-
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
2619+
mqk[j] += mq[j][i] * mk;
26652620
}
26662621
}
26672622

2668-
// mqk = mqk*scale + mask
2623+
// reduce the results from the threads in the simdgroup
26692624
for (short j = 0; j < Q; ++j) {
2670-
simdgroup_half8x8 mm;
2671-
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
2672-
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
2625+
mqk[j] += simd_shuffle_down(mqk[j], 16);
2626+
mqk[j] += simd_shuffle_down(mqk[j], 8);
2627+
mqk[j] += simd_shuffle_down(mqk[j], 4);
2628+
mqk[j] += simd_shuffle_down(mqk[j], 2);
2629+
mqk[j] += simd_shuffle_down(mqk[j], 1);
2630+
}
26732631

2674-
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
2632+
// mqk = mqk*scale + mask
2633+
if (tiisg == 0) {
2634+
for (short j = 0; j < Q; ++j) {
2635+
half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
2636+
mqk[j] = mqk[j]*mscale + mm;
2637+
2638+
ss4[j*T4 + cc] = mqk[j];
2639+
}
26752640
}
26762641
}
26772642
}
@@ -2701,26 +2666,26 @@ kernel void kernel_flash_attn_ext_vec_f16(
27012666
ss[tiisg*T + C + tiisg] = ms[tiisg];
27022667
}
27032668

2704-
//threadgroup_barrier(mem_flags::mem_threadgroup);
2705-
27062669
// O = diag(ms)*O
27072670
for (short j = 0; j < Q; ++j) {
2708-
//simdgroup_half8x8 mm;
2709-
//simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
27102671
half mm(ss[j*T + C + j]);
27112672

2712-
for (short i = tiisg; i < D4; i += NW) {
2713-
//simdgroup_multiply(lo[j][i], mm, lo[j][i]);
2673+
#pragma unroll
2674+
for (short ii = 0; ii < D4; ii += NW) {
2675+
const short i = ii + tiisg;
27142676
lo[j][i/NW] = lo[j][i/NW]*mm;
27152677
}
27162678
}
27172679

27182680
// O = O + (Q*K^T)*V
27192681
{
2682+
#pragma unroll
27202683
for (short cc = 0; cc < C; ++cc) {
27212684
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
27222685

2723-
for (short i = tiisg; i < D4; i += NW) {
2686+
#pragma unroll
2687+
for (short ii = 0; ii < D4; ii += NW) {
2688+
short i = ii + tiisg;
27242689
for (short j = 0; j < Q; ++j) {
27252690
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
27262691
}
@@ -2738,15 +2703,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
27382703
}
27392704
}
27402705

2741-
threadgroup_barrier(mem_flags::mem_threadgroup);
2742-
27432706
// store results to shared memory
27442707
for (short j = 0; j < Q; ++j) {
2745-
for (short i = tiisg; i < D4; i += NW) {
2746-
sr4[i] = lo[j][i/NW];
2708+
for (short ii = 0; ii < D4; ii += NW) {
2709+
short i = ii + tiisg;
2710+
sr4[i] = lo[j][ii/NW];
27472711
}
27482712
}
27492713

2714+
threadgroup_barrier(mem_flags::mem_threadgroup);
2715+
27502716
// parallel reduce
27512717
for (short r = nsg/2; r > 0; r >>= 1) {
27522718
if (sgitg < r) {
@@ -2805,10 +2771,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
28052771
}
28062772
}
28072773

2808-
template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>;
2809-
template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>;
2810-
template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>;
2811-
template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>;
28122774
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
28132775
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
28142776

0 commit comments

Comments
 (0)