1 file changed: +3 -3 lines changed
Original file line number | Diff line number | Diff line change
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
3192
3192
3193
3193
{
3194
3194
float S[Q] = { [0 ... Q-1] = 0.0f };
3195
- float M[Q] = { [0 ... Q-1 ] = -__FLT16_MAX__ /2 };
3195
+ float M[Q] = { [0 ... Q-1 ] = -__FLT_MAX__ /2 };
3196
3196
3197
3197
// thread indices inside the simdgroup
3198
3198
// TODO: see if we can utilize quad-group functions for better performance
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
3452
3452
// reduce the warps sequentially
3453
3453
for (ushort sg = 1 ; sg < nsg; ++sg) {
3454
3454
float S = { 0.0f };
3455
- float M = { -__FLT16_MAX__ /2 };
3455
+ float M = { -__FLT_MAX__ /2 };
3456
3456
3457
3457
threadgroup_barrier (mem_flags::mem_threadgroup);
3458
3458
@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
3699
3699
3700
3700
{
3701
3701
float S = 0.0f;
3702
- float M = -__FLT16_MAX__ /2 ;
3702
+ float M = -__FLT_MAX__ /2 ;
3703
3703
3704
3704
// thread indices inside the simdgroup
3705
3705
const short tx = tiisg%NL;
You can’t perform that action at this time.
0 commit comments