Skip to content

Commit 0c41e03

Browse files
authored
metal : gemma2 flash attention support (#9159)
1 parent f12ceac commit 0c41e03

File tree

3 files changed

+54
-44
lines changed

3 files changed

+54
-44
lines changed

ggml/src/ggml-metal.m

Lines changed: 32 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
802802
if (op->src[0]->ne[0] == 256) {
803803
return false;
804804
}
805-
{
806-
float logit_softcap;
807-
808-
memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
809-
810-
if (logit_softcap != 0.0f) {
811-
return false;
812-
}
813-
}
814805
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
815806
case GGML_OP_MUL_MAT:
816807
case GGML_OP_MUL_MAT_ID:
@@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute(
26332624

26342625
float scale;
26352626
float max_bias;
2627+
float logit_softcap;
2628+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2629+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2630+
memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
26362631

2637-
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2638-
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2632+
if (logit_softcap != 0.0f) {
2633+
scale /= logit_softcap;
2634+
}
26392635

26402636
const uint32_t n_head = src0->ne[2];
26412637
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -2686,30 +2682,31 @@ static enum ggml_status ggml_metal_graph_compute(
26862682
} else {
26872683
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
26882684
}
2689-
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2690-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2691-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2692-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2693-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2694-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2695-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2696-
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2697-
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2698-
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2699-
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2700-
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2701-
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2702-
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2703-
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2704-
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2705-
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2706-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2707-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2708-
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
2709-
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2710-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2711-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2712-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2685+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2686+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2687+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2688+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2689+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2690+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2691+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2692+
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2693+
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2694+
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2695+
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2696+
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2697+
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2698+
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2699+
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2700+
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2701+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2702+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2703+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2704+
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
2705+
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2706+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2707+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2708+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2709+
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
27132710

27142711
if (!use_vec_kernel) {
27152712
// half8x8 kernel

ggml/src/ggml-metal.metal

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1976,6 +1976,7 @@ typedef void (flash_attn_ext_f16_t)(
19761976
constant float & m0,
19771977
constant float & m1,
19781978
constant uint32_t & n_head_log2,
1979+
constant float & logit_softcap,
19791980
threadgroup half * shared,
19801981
uint3 tgpig[[threadgroup_position_in_grid]],
19811982
uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2014,6 +2015,7 @@ kernel void kernel_flash_attn_ext_f16(
20142015
constant float & m0,
20152016
constant float & m1,
20162017
constant uint32_t & n_head_log2,
2018+
constant float & logit_softcap,
20172019
threadgroup half * shared [[threadgroup(0)]],
20182020
uint3 tgpig[[threadgroup_position_in_grid]],
20192021
uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2142,14 +2144,19 @@ kernel void kernel_flash_attn_ext_f16(
21422144
const short tx = tiisg%4;
21432145
const short ty = tiisg/4;
21442146

2147+
// mqk = mqk*scale
2148+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2149+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2150+
2151+
if (logit_softcap != 0.0f) {
2152+
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
2153+
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
2154+
}
2155+
21452156
if (mask != q) {
2146-
// mqk = mqk*scale + mask*slope
2147-
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2148-
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2149-
} else {
2150-
// mqk = mqk*scale
2151-
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2152-
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2157+
// mqk = mqk + mask*slope
2158+
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2159+
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
21532160
}
21542161
}
21552162
}
@@ -2345,6 +2352,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
23452352
constant float & m0,
23462353
constant float & m1,
23472354
constant uint32_t & n_head_log2,
2355+
constant float & logit_softcap,
23482356
threadgroup half * shared [[threadgroup(0)]],
23492357
uint3 tgpig[[threadgroup_position_in_grid]],
23502358
uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2479,7 +2487,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
24792487

24802488
// mqk = mqk*scale + mask*slope
24812489
if (tiisg == 0) {
2482-
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
2490+
mqk *= scale;
2491+
2492+
if (logit_softcap != 0.0f) {
2493+
mqk = logit_softcap*precise::tanh(mqk);
2494+
}
2495+
2496+
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
24832497

24842498
ss4[cc] = mqk;
24852499
}

tests/test-backend-ops.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2487,7 +2487,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
24872487
}
24882488

24892489
GGML_ABORT("fatal error");
2490-
return false;
24912490
}
24922491

24932492
static void usage(char ** argv) {

0 commit comments

Comments
 (0)