Skip to content

metal : more precise Q*K in FA vec kernel #10247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
half smax = -INFINITY;

// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);

Expand All @@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

#pragma unroll
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
k8x8_t mk;
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
Expand All @@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(

simdgroup_barrier(mem_flags::mem_threadgroup);

#pragma unroll
#pragma unroll(4)
for (short k = 0; k < 4; ++k) {
k8x8_t mk;

Expand Down Expand Up @@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
s8x8_t mm;
simdgroup_load(mm, ss + 2*C, TS, 0, false);

#pragma unroll
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
simdgroup_multiply(lo[i], mm, lo[i]);
}
Expand All @@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll

#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
v8x8_t mv;
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
Expand All @@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(

simdgroup_barrier(mem_flags::mem_threadgroup);

#pragma unroll
#pragma unroll(4)
for (short k = 0; k < 4; ++k) {
v8x8_t mv;

Expand Down Expand Up @@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);

#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
o8x8_t t;

Expand Down Expand Up @@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
// load the queries from shared memory into local memory
q4x4_t mq[D16/NL];

#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
mq[ii/NL] = sq4x4[ii + tx];
}
Expand Down Expand Up @@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(

device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

#pragma unroll
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
const short i = ii + tx;

k4x4_t mk;
deq_k(pk + i/nl_k, i%nl_k, mk);

mqka[0] += dot(mq[ii/NL][0], mk[0]);
mqka[1] += dot(mq[ii/NL][1], mk[1]);
mqka[2] += dot(mq[ii/NL][2], mk[2]);
mqka[3] += dot(mq[ii/NL][3], mk[3]);
// note: this is less precise than the version below
//mqka[0] += dot(mq[ii/NL][0], mk[0]);
//mqka[1] += dot(mq[ii/NL][1], mk[1]);
//mqka[2] += dot(mq[ii/NL][2], mk[2]);
//mqka[3] += dot(mq[ii/NL][3], mk[3]);

mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
}

qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
Expand Down Expand Up @@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
ss[tiisg] = vs;

// O = diag(ms)*O
#pragma unroll
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
lo[ii/NL] *= ms;
}
Expand All @@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(

// O = O + (Q*K^T)*V
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

const s4x4_t ms(ss[4*cc + ty]);

#pragma unroll
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
const short i = ii + tx;

Expand Down
Loading