Skip to content

Commit 5733b00

Browse files
committed
metal : opt
1 parent 8d2a61f commit 5733b00

File tree

1 file changed

+28
-74
lines changed

1 file changed

+28
-74
lines changed

ggml-metal.metal

Lines changed: 28 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,7 +2581,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
25812581
}
25822582

25832583
// pointer to the mask
2584-
device const half * mp = (device const half *) (mask + iq1*nb31);
2584+
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
25852585

25862586
// prepare diagonal scale matrix
25872587
//simdgroup_half8x8 mscale(scale);
@@ -2597,23 +2597,23 @@ kernel void kernel_flash_attn_ext_vec_f16(
25972597

25982598
// Q*K^T
25992599
{
2600-
for (short cc = 0; cc < C; ++cc) {
2601-
half mqk[Q];
2600+
for (short cc = 0; cc < C/4; ++cc) {
2601+
half4 mqk[Q];
26022602
for (short j = 0; j < Q; ++j) {
26032603
mqk[j] = 0.0h;
26042604
}
26052605

2606-
//device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
2607-
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));
2606+
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
26082607

26092608
for (short i = tiisg; i < D4; i += NW) {
2610-
//simdgroup_half8x8 mk;
2611-
//simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
2612-
half4 mk = pk4[i];
2609+
half4x4 mk;
2610+
mk[0] = pk4[i + 0*(nb11/8)];
2611+
mk[1] = pk4[i + 1*(nb11/8)];
2612+
mk[2] = pk4[i + 2*(nb11/8)];
2613+
mk[3] = pk4[i + 3*(nb11/8)];
26132614

26142615
for (short j = 0; j < Q; ++j) {
2615-
//simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
2616-
mqk[j] += dot(mq[j][i], mk);
2616+
mqk[j] += mq[j][i] * mk;
26172617
}
26182618
}
26192619

@@ -2633,85 +2633,40 @@ kernel void kernel_flash_attn_ext_vec_f16(
26332633
// mqk = mqk*scale + mask
26342634
if (tiisg == 0) {
26352635
for (short j = 0; j < Q; ++j) {
2636-
//simdgroup_half8x8 mm;
2637-
//simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
2638-
//simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
2639-
2640-
//simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
2641-
2642-
half mm = mp[j*(nb31/sizeof(half)) + ic + cc];
2636+
half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
26432637
mqk[j] = mqk[j]*mscale + mm;
26442638

2645-
ss[j*T + cc] = mqk[j];
2639+
ss4[j*T4 + cc] = mqk[j];
26462640
}
26472641
}
26482642
}
26492643
}
26502644

2651-
//threadgroup_barrier(mem_flags::mem_threadgroup);
2645+
simdgroup_barrier(mem_flags::mem_threadgroup);
26522646

26532647
// online softmax
2654-
if (C == 32) {
2655-
half ms[Q];
2656-
2657-
for (short j = 0; j < Q; ++j) {
2658-
const short p = tiisg;
2659-
2660-
const half m = M[j];
2661-
const half s = ss[j*T + p];
2662-
2663-
M[j] = simd_max(max(M[j], s));
2664-
2665-
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
2666-
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
2667-
2668-
S[j] = S[j]*ms[j] + simd_sum(vs);
2669-
2670-
// the P matrix from the paper (Q rows, C columns)
2671-
ss[j*T + p] = vs;
2672-
}
2673-
2674-
// create a QxQ diagonal matrix for rescaling the output
2675-
if (tiisg < Q) {
2676-
ss[tiisg*T + C + tiisg] = ms[tiisg];
2677-
}
2678-
} else {
2679-
half ms[Q];
2648+
half ms[Q];
26802649

2681-
for (short j = 0; j < Q; ++j) {
2682-
const half m = M[j];
2683-
2684-
for (short p = tiisg; p < C; p += NW) {
2685-
const half s = ss[j*T + p];
2686-
2687-
M[j] = max(M[j], s);
2688-
}
2689-
2690-
M[j] = simd_max(M[j]);
2691-
2692-
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
2693-
2694-
// local sum
2695-
half ls = 0.0h;
2650+
for (short j = 0; j < Q; ++j) {
2651+
const short p = tiisg;
26962652

2697-
for (short p = tiisg; p < C; p += NW) {
2698-
const half s = ss[j*T + p];
2653+
const half m = M[j];
2654+
const half s = ss[j*T + p];
26992655

2700-
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
2656+
M[j] = simd_max(max(M[j], s));
27012657

2702-
ls += vs;
2658+
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
2659+
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
27032660

2704-
// the P matrix from the paper (Q rows, C columns)
2705-
ss[j*T + p] = vs;
2706-
}
2661+
S[j] = S[j]*ms[j] + simd_sum(vs);
27072662

2708-
S[j] = S[j]*ms[j] + simd_sum(ls);
2709-
}
2663+
// the P matrix from the paper (Q rows, C columns)
2664+
ss[j*T + p] = vs;
2665+
}
27102666

2711-
// create a QxQ diagonal matrix for rescaling the output
2712-
if (tiisg < Q) {
2713-
ss[tiisg*T + C + tiisg] = ms[tiisg];
2714-
}
2667+
// create a QxQ diagonal matrix for rescaling the output
2668+
if (tiisg < Q) {
2669+
ss[tiisg*T + C + tiisg] = ms[tiisg];
27152670
}
27162671

27172672
//threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -2733,7 +2688,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
27332688
for (short cc = 0; cc < C; ++cc) {
27342689
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
27352690

2736-
half vsum[Q];
27372691
for (short i = tiisg; i < D4; i += NW) {
27382692
for (short j = 0; j < Q; ++j) {
27392693
lo[j][i] += pv4[i]*ss[j*T + cc];

0 commit comments

Comments
 (0)