Skip to content

Commit 230d0c6

Browse files
ggerganovpockers21
authored andcommitted
metal : fix floating-point range of attention scores in FA kernels (ggml-org#13090)
ggml-ci
1 parent 58e6fbc commit 230d0c6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
31923192

31933193
{
31943194
float S[Q] = { [0 ... Q-1] = 0.0f };
3195-
float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
3195+
float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
31963196

31973197
// thread indices inside the simdgroup
31983198
// TODO: see if we can utilize quad-group functions for better performance
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
34523452
// reduce the warps sequentially
34533453
for (ushort sg = 1; sg < nsg; ++sg) {
34543454
float S = { 0.0f };
3455-
float M = { -__FLT16_MAX__/2 };
3455+
float M = { -__FLT_MAX__/2 };
34563456

34573457
threadgroup_barrier(mem_flags::mem_threadgroup);
34583458

@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
36993699

37003700
{
37013701
float S = 0.0f;
3702-
float M = -__FLT16_MAX__/2;
3702+
float M = -__FLT_MAX__/2;
37033703

37043704
// thread indices inside the simdgroup
37053705
const short tx = tiisg%NL;

0 commit comments

Comments
 (0)