3
3
4
4
#include < mma.h>
5
5
6
- #define FATTN_KQ_STRIDE 256
7
- #define HALF_MAX_HALF __float2half (65504 .0f /2 ) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
6
+ #define FATTN_KQ_STRIDE 256
7
+ #define HALF_MAX_HALF __float2half (65504 .0f /2 ) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
8
+ #define SOFTMAX_FTZ_THRESHOLD -20 .0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
8
9
9
10
template <int D, int parallel_blocks> // D == head size
10
11
__launch_bounds__ (((D + WARP_SIZE - 1 ) / WARP_SIZE)*WARP_SIZE, 1)
@@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16(
338
339
#pragma unroll
339
340
for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
340
341
const int k = k0 + threadIdx .x ;
341
- KQ_max_new = __hmax2 (KQ_max_new, KQ2[j*(kqs_padded/2 ) + k]);
342
+ half2 val = KQ2[j*(kqs_padded/2 ) + k];
343
+ val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
344
+ KQ_max_new = __hmax2 (KQ_max_new, val);
345
+ KQ2[j*(kqs_padded/2 ) + k] = val;
342
346
}
343
347
KQ_max_new = __half2half2 (warp_reduce_max (__hmax (__low2half (KQ_max_new), __high2half (KQ_max_new))));
344
- KQ_max_scale[j0/nwarps] = h2exp (KQ_max[j0/nwarps] - KQ_max_new);
348
+ const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
349
+ KQ_max_scale[j0/nwarps] = h2exp (diff);
350
+ const uint ftz_mask = __hgt2_mask (diff, make_half2 (SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
351
+ *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask;
345
352
KQ_max[j0/nwarps] = KQ_max_new;
346
353
347
354
half2 KQ_rowsum_add = make_half2 (0 .0f , 0 .0f );
@@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16(
350
357
const int k = k0 + threadIdx .x ;
351
358
352
359
half2 val = KQ2[j*(kqs_padded/2 ) + k];
353
- val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
354
- val = h2exp (val - KQ_max[j0/nwarps]);
360
+ const half2 diff = val - KQ_max[j0/nwarps];
361
+ val = h2exp (diff);
362
+ const uint ftz_mask = __hgt2_mask (diff, make_half2 (SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
363
+ *((uint *) &val) &= ftz_mask;
355
364
KQ_rowsum_add += val;
356
365
KQ2[j*(kqs_padded/2 ) + k] = val;
357
366
}
@@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results(
501
510
float VKQ_denominator = 0 .0f ;
502
511
#pragma unroll
503
512
for (int l = 0 ; l < parallel_blocks; ++l) {
504
- float KQ_max_scale = hexp (__low2half (meta[l]) - kqmax);
513
+ const half diff = __low2half (meta[l]) - kqmax;
514
+ float KQ_max_scale = hexp (diff);
515
+ const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half (SOFTMAX_FTZ_THRESHOLD));
516
+ *((uint *) &KQ_max_scale) &= ftz_mask;
505
517
506
518
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim .y *D + blockIdx .y *D + tid];
507
519
VKQ_denominator += KQ_max_scale * __high2float (meta[l]);
0 commit comments