Skip to content

Commit 1f63e75

Browse files
authored
metal : use less stack memory in FA kernel (#14088)
* metal : use less stack memory in FA kernel ggml-ci * cont : fix BF16 variant
1 parent 40cbf57 commit 1f63e75

File tree

1 file changed

+54
-59
lines changed

1 file changed

+54
-59
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 54 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext(
33333333

33343334
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
33353335
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3336-
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3337-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
33383336
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
33393337

33403338
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext(
35483546

35493547
// O = diag(ms)*O
35503548
{
3551-
s8x8_t mm;
3552-
simdgroup_load(mm, ss + 2*C, TS, 0, false);
3549+
s8x8_t ms;
3550+
simdgroup_load(ms, ss + 2*C, TS, 0, false);
35533551

35543552
#pragma unroll(DV8)
35553553
for (short i = 0; i < DV8; ++i) {
3556-
simdgroup_multiply(lo[i], mm, lo[i]);
3554+
simdgroup_multiply(lo[i], ms, lo[i]);
35573555
}
35583556
}
35593557

35603558
// O = O + (Q*K^T)*V
35613559
{
35623560
for (short cc = 0; cc < C/8; ++cc) {
3563-
s8x8_t ms;
3564-
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
3561+
s8x8_t vs;
3562+
simdgroup_load(vs, ss + 8*cc, TS, 0, false);
35653563

35663564
if (is_same<vd4x4_t, v4x4_t>::value) {
35673565
// we can read directly from global memory
@@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext(
35723570
v8x8_t mv;
35733571
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
35743572

3575-
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
3573+
simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
35763574
}
35773575
} else {
35783576
for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext(
35933591
v8x8_t mv;
35943592

35953593
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
3596-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
3594+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
35973595

35983596
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
3599-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3597+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
36003598
}
36013599
} else {
36023600
if (ii + tx < DV16) {
@@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext(
36113609
v8x8_t mv;
36123610

36133611
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
3614-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
3612+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
36153613

36163614
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
3617-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3615+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
36183616
}
36193617
}
36203618
}
@@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext(
36243622
}
36253623

36263624
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
3627-
for (short j = 0; j < Q; ++j) {
3628-
if (tiisg == 0) {
3629-
ss[j*TS + 0] = S[j];
3630-
ss[j*TS + 1] = M[j];
3631-
}
3625+
for (short j = tiisg; j < Q; j += NW) {
3626+
ss[j*TS + 0] = S[j];
3627+
ss[j*TS + 1] = M[j];
36323628
}
36333629
}
36343630

3635-
// reduce the warps sequentially
3636-
for (ushort sg = 1; sg < nsg; ++sg) {
3637-
threadgroup_barrier(mem_flags::mem_threadgroup);
3631+
threadgroup_barrier(mem_flags::mem_threadgroup);
36383632

3639-
// each simdgroup stores its output to shared memory, reusing sq
3640-
if (sgitg == sg) {
3641-
for (short i = 0; i < DV8; ++i) {
3642-
simdgroup_store(lo[i], so + i*8, DV, 0, false);
3643-
}
3633+
threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3634+
threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
3635+
3636+
// store result to shared memory in F32
3637+
if (sgitg == 0) {
3638+
for (short i = 0; i < DV8; ++i) {
3639+
//simdgroup_store(lo[i], so + i*8, DV, 0, false);
3640+
simdgroup_float8x8 t(1.0f);
3641+
simdgroup_multiply(t, lo[i], t);
3642+
simdgroup_store(t, so + i*8, DV, 0, false);
36443643
}
3644+
}
36453645

3646-
threadgroup_barrier(mem_flags::mem_threadgroup);
3646+
threadgroup_barrier(mem_flags::mem_threadgroup);
36473647

3648-
// the first simdgroup accumulates the results from the other simdgroups
3649-
if (sgitg == 0) {
3650-
for (short j = 0; j < Q; ++j) {
3651-
const float S0 = ss[j*TS + 0];
3652-
const float S1 = ss[j*TS + sg*SH + 0];
3648+
// reduce the warps sequentially
3649+
for (ushort sg = 1; sg < nsg; ++sg) {
3650+
if (sgitg == sg) {
3651+
for (short j = tiisg; j < Q; j += NW) {
3652+
const float S0 = ss[j*TS - 1*SH + 0];
3653+
const float S1 = ss[j*TS + 0];
36533654

3654-
const float M0 = ss[j*TS + 1];
3655-
const float M1 = ss[j*TS + sg*SH + 1];
3655+
const float M0 = ss[j*TS - 1*SH + 1];
3656+
const float M1 = ss[j*TS + 1];
36563657

36573658
const float M = max(M0, M1);
36583659

3659-
const float ms0 = exp(M0 - M);
3660-
const float ms1 = exp(M1 - M);
3660+
float ms0 = exp(M0 - M);
3661+
float ms1 = exp(M1 - M);
36613662

36623663
const float S = S0*ms0 + S1*ms1;
36633664

3664-
if (tiisg == 0) {
3665-
ss[j*TS + 0] = S;
3666-
ss[j*TS + 1] = M;
3665+
ss[j*TS + 0] = S;
3666+
ss[j*TS + 1] = M;
36673667

3668-
ss[j*TS + 2*C + j ] = ms0;
3669-
ss[j*TS + 2*C + j + sg*SH] = ms1;
3670-
}
3668+
ss[j*TS + 2*C + j - 1*SH] = ms0;
3669+
ss[j*TS + 2*C + j ] = ms1;
36713670
}
36723671

3672+
//simdgroup_barrier(mem_flags::mem_threadgroup);
3673+
36733674
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
36743675
{
36753676
s8x8_t ms0;
36763677
s8x8_t ms1;
36773678

3678-
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3679-
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3679+
simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
3680+
simdgroup_load(ms1, ss + 2*C, TS, 0, false);
36803681

36813682
#pragma unroll(DV8)
36823683
for (short i = 0; i < DV8; ++i) {
3683-
o8x8_t t;
3684+
simdgroup_float8x8 t;
36843685

36853686
simdgroup_load (t, so + i*8, DV, 0, false);
3686-
simdgroup_multiply(t, ms1, t);
3687+
simdgroup_multiply(t, ms0, t);
36873688

3688-
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
3689+
simdgroup_multiply_accumulate(t, ms1, lo[i], t);
3690+
simdgroup_store(t, so + i*8, DV, 0, false);
36893691
}
36903692
}
36913693
}
3692-
}
36933694

3694-
// store result to shared memory (reuse sq)
3695-
if (sgitg == 0) {
3696-
for (short i = 0; i < DV8; ++i) {
3697-
simdgroup_store(lo[i], so + i*8, DV, 0, false);
3698-
}
3695+
threadgroup_barrier(mem_flags::mem_threadgroup);
36993696
}
37003697

3701-
threadgroup_barrier(mem_flags::mem_threadgroup);
3702-
3703-
threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
3698+
threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
37043699

37053700
// final rescale with 1/S and store to global memory
37063701
for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
@@ -3723,17 +3718,17 @@ kernel void kernel_flash_attn_ext(
37233718
half, half4x4, simdgroup_half8x8, \
37243719
float, simdgroup_float8x8, \
37253720
float, simdgroup_float8x8, \
3726-
float, float4, simdgroup_float8x8
3727-
//half, half4, simdgroup_half8x8
3721+
half, half4, simdgroup_half8x8
3722+
//float, float4, simdgroup_float8x8
37283723

37293724
#define FA_TYPES_BF \
37303725
bfloat, bfloat4, simdgroup_bfloat8x8, \
37313726
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
37323727
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
37333728
float, simdgroup_float8x8, \
37343729
float, simdgroup_float8x8, \
3735-
float, float4, simdgroup_float8x8
3736-
//half, half4, simdgroup_half8x8
3730+
half, half4, simdgroup_half8x8
3731+
//float, float4, simdgroup_float8x8
37373732

37383733
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
37393734

0 commit comments

Comments
 (0)