metal : use F32 prec in FA kernels #12688

Merged
merged 2 commits on Apr 1, 2025

ggml/src/ggml-metal/ggml-metal.m (2 changes: 1 addition & 1 deletion)
@@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
// ne00*(nsg)
// each simdgroup has a full f16 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;
while (true) {
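Note: the host-side FATTN_SMEM estimate above has to track the kernel-side change below, where the per-simdgroup scratch SH grows from 2*C to 4*C: the softmax scratch is now stored as float, so it occupies twice as many half-sized slots. A minimal sketch of the sizing arithmetic in plain C++, with a pad() helper mirroring GGML_PAD and illustrative values for nqptg, ne00, ncpsg, ne20 and nsg (these particular numbers are assumptions, not taken from the PR):

#include <cstdio>
#include <cstddef>

// round x up to a multiple of n (n must be a power of two), as GGML_PAD does
static size_t pad(size_t x, size_t n) { return (x + n - 1) & ~(n - 1); }

int main() {
    // illustrative values: 8 queries per threadgroup, K head size 128,
    // 32 cache entries per simdgroup, V head size 128, 2 simdgroups
    const size_t nqptg = 8, ne00 = 128, ncpsg = 32, ne20 = 128, nsg = 2;

    // before: the scratch was half, i.e. 2*ncpsg half-sized slots per simdgroup
    const size_t smem_old = pad((nqptg*(pad(ne00, 128) + 2*ncpsg*nsg) + ne20*nsg)*(sizeof(float)/2), 16);
    // after: the scratch is float, i.e. 4*ncpsg half-sized slots per simdgroup
    const size_t smem_new = pad((nqptg*(pad(ne00, 128) + 4*ncpsg*nsg) + ne20*nsg)*(sizeof(float)/2), 16);

    printf("threadgroup memory: %zu -> %zu bytes\n", smem_old, smem_new); // 4608 -> 6656 with these values
}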
ggml/src/ggml-metal/ggml-metal.metal (94 changes: 47 additions & 47 deletions)
@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
threadgroup_barrier(mem_flags::mem_threadgroup);

{
half S[Q] = { [0 ... Q-1] = 0.0f };
half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
float S[Q] = { [0 ... Q-1] = 0.0f };
float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };

// thread indices inside the simdgroup
// TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@

const bool has_mask = mask != q;

half slope = 1.0f;
float slope = 1.0f;

// ALiBi
if (args.max_bias > 0.0f) {
const short h = iq2;

const half base = h < args.n_head_log2 ? args.m0 : args.m1;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

slope = pow(base, exph);
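The ALiBi slope is now computed in float: base is a fraction below 1, and raising it to the per-head exponent can lose precision or underflow in half. A scalar restatement of the slope selection in plain C++ (the function name is ours; n_head_log2, m0 and m1 stand for the same per-model constants the kernel receives in args):

#include <cmath>

// per-head ALiBi slope, following the kernel logic above; h is the head index
static float alibi_slope(int h, int n_head_log2, float m0, float m1) {
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, (float) exph);
}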
@@ -3224,14 +3224,14 @@

if (has_mask) {
// used to detect blocks full of -INF
half smax = -INFINITY;
float smax = -INFINITY;

// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);

const half m = pm[ic + tiisg];
const float m = pm[ic + tiisg];

ss[j*TS + C + tiisg] = m;
smax = max(smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
// online softmax
{
for (ushort j = 0; j < Q; ++j) {
const half m = M[j];
const float m = M[j];

// scale and apply the logitcap / mask
half s = ss[j*TS + tiisg]*args.scale;
float s = ss[j*TS + tiisg]*args.scale;

if (args.logit_softcap != 0.0f) {
s = args.logit_softcap*precise::tanh(s);
@@ -3341,8 +3341,8 @@

M[j] = simd_max(max(M[j], s));

const half ms = exp(m - M[j]);
const half vs = exp(s - M[j]);
const float ms = exp(m - M[j]);
const float vs = exp(s - M[j]);

S[j] = S[j]*ms + simd_sum(vs);
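For reference, S[j] and M[j] follow the standard online-softmax recurrence, and keeping them and the two exp() factors in float is what the F32-precision change amounts to here. A scalar sketch of one update step in plain C++ (the simd_max/simd_sum cross-lane reductions are omitted; the names mirror the kernel):

#include <cmath>

// one online-softmax step for a single query row:
// M is the running maximum, S the running sum of exp(score - M);
// the caller rescales its output accumulator by ms and weights the new value by vs.
static void online_softmax_step(float s, float &S, float &M, float &ms, float &vs) {
    const float m = M;   // previous maximum
    M  = fmaxf(M, s);    // new running maximum
    ms = expf(m - M);    // rescale factor for the existing accumulator
    vs = expf(s - M);    // weight of the new score
    S  = S*ms + vs;      // updated normalizer
}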

@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(

// reduce the warps sequentially
for (ushort sg = 1; sg < nsg; ++sg) {
half S = { 0.0f };
half M = { -__FLT16_MAX__/2 };
float S = { 0.0f };
float M = { -__FLT16_MAX__/2 };

threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -3461,16 +3461,16 @@
// the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) {
for (short j = 0; j < Q; ++j) {
const half S0 = ss[j*TS + 0];
const half S1 = ss[j*TS + sg*SH + 0];
const float S0 = ss[j*TS + 0];
const float S1 = ss[j*TS + sg*SH + 0];

const half M0 = ss[j*TS + 1];
const half M1 = ss[j*TS + sg*SH + 1];
const float M0 = ss[j*TS + 1];
const float M1 = ss[j*TS + sg*SH + 1];

M = max(M0, M1);

const half ms0 = exp(M0 - M);
const half ms1 = exp(M1 - M);
const float ms0 = exp(M0 - M);
const float ms1 = exp(M1 - M);

S = S0*ms0 + S1*ms1;

@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
constexpr short DV4 = DV/4;
constexpr short NW = N_SIMDWIDTH;
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
constexpr short SH = 2*C; // shared memory per simdgroup
constexpr short SH = 4*C; // shared memory per simdgroup

const short T = DK + nsg*SH; // shared memory size per query in (half)

//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results

// store the result for all queries in local memory (the O matrix from the paper)
o4_t lo[DV4/NL];
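With SH = 4*C, each simdgroup's scratch region in shmem_f16 doubles: the float attention scores take the first 2*C half-sized slots and the float mask values the next 2*C, which is why sm moves from offset C to offset 2*C. A small sketch of the resulting offsets, counted in half-sized elements (the parameter values are illustrative only):

#include <cstdio>

int main() {
    // illustrative parameters: K head size 128, 1 query per threadgroup,
    // 32 cache entries per simdgroup, 2 simdgroups
    const int DK = 128, Q = 1, C = 32, nsg = 2;
    const int SH = 4*C;          // shared memory per simdgroup (half-sized elements)
    const int T  = DK + nsg*SH;  // shared memory per query (half-sized elements)

    for (int sgitg = 0; sgitg < nsg; ++sgitg) {
        const int off_ss = Q*DK + sgitg*SH;        // C float scores = 2*C half-sized slots
        const int off_sm = Q*DK + sgitg*SH + 2*C;  // C float mask values = 2*C half-sized slots
        printf("simdgroup %d: ss at %d, sm at %d, T = %d\n", sgitg, off_ss, off_sm, T);
    }
    return 0;
}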
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup);

{
half S = 0.0f;
half M = -__FLT16_MAX__/2;
float S = 0.0f;
float M = -__FLT16_MAX__/2;

// thread indices inside the simdgroup
const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@
// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31);

half slope = 1.0f;
float slope = 1.0f;

// ALiBi
if (args.max_bias > 0.0f) {
const short h = iq2;

const half base = h < args.n_head_log2 ? args.m0 : args.m1;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

slope = pow(base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(

// online softmax
{
const half m = M;
const half s = ss[tiisg];
const float m = M;
const float s = ss[tiisg];

M = simd_max(max(M, s));

const half ms = exp(m - M);
const half vs = exp(s - M);
const float ms = exp(m - M);
const float vs = exp(s - M);

S = S*ms + simd_sum(vs);

@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
v4_t mv;
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);

lo[ii/NL] += mv*ms;
lo[ii/NL] += o4_t(float4(mv)*float4(ms));
}
}
}
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
// parallel reduce
for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) {
const half S0 = ss[ 0];
const half S1 = ss[r*SH + 0];
const float S0 = ss[ 0];
const float S1 = ss[r*(SH/2) + 0];

const half M0 = ss[ 1];
const half M1 = ss[r*SH + 1];
const float M0 = ss[ 1];
const float M1 = ss[r*(SH/2) + 1];

const half M = max(M0, M1);
const float M = max(M0, M1);

const half ms0 = exp(M0 - M);
const half ms1 = exp(M1 - M);
const float ms0 = exp(M0 - M);
const float ms1 = exp(M1 - M);

const half S = S0*ms0 + S1*ms1;
const float S = S0*ms0 + S1*ms1;

if (tiisg == 0) {
ss[0] = S;
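The parallel reduction now merges the per-simdgroup (S, M) pairs in float; the index changes from r*SH to r*(SH/2) because ss addresses the scratch in float elements while SH is counted in half-sized elements. A scalar sketch of merging two partial online-softmax states in plain C++ (illustration only; in the kernel the matching output accumulators are also rescaled by ms0/ms1):

#include <cmath>

// combine two partial states (S0, M0) and (S1, M1) into one
static void merge_softmax_states(float S0, float M0, float S1, float M1,
                                 float &S, float &M, float &ms0, float &ms1) {
    M   = fmaxf(M0, M1);     // joint maximum
    ms0 = expf(M0 - M);      // rescale factor for the first partial accumulator
    ms1 = expf(M1 - M);      // rescale factor for the second partial accumulator
    S   = S0*ms0 + S1*ms1;   // combined normalizer
}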
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
//
#define FA_TYPES \
half4, \
half4, \
half4, \
float, \
half, half4, \
half4, \
half4, \
half4, \
float, \
float, float4, \
half4

typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;