Skip to content

Commit ef9e159

Browse files
flush softmax exp below threshold to 0
1 parent 6a3b842 commit ef9e159

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

ggml-cuda/fattn.cu

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
#include <mma.h>
55

6-
#define FATTN_KQ_STRIDE 256
7-
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
6+
#define FATTN_KQ_STRIDE 256
7+
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
8+
#define SOFTMAX_FTZ_THRESHOLD -20.0f // exp() results for softmax arguments less than or equal to this threshold are flushed to zero to avoid NaNs.
89

910
template<int D, int parallel_blocks> // D == head size
1011
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
@@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16(
338339
#pragma unroll
339340
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
340341
const int k = k0 + threadIdx.x;
341-
KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]);
342+
half2 val = KQ2[j*(kqs_padded/2) + k];
343+
val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
344+
KQ_max_new = __hmax2(KQ_max_new, val);
345+
KQ2[j*(kqs_padded/2) + k] = val;
342346
}
343347
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
344-
KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new);
348+
const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
349+
KQ_max_scale[j0/nwarps] = h2exp(diff);
350+
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
351+
*((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask;
345352
KQ_max[j0/nwarps] = KQ_max_new;
346353

347354
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
@@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16(
350357
const int k = k0 + threadIdx.x;
351358

352359
half2 val = KQ2[j*(kqs_padded/2) + k];
353-
val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
354-
val = h2exp(val - KQ_max[j0/nwarps]);
360+
const half2 diff = val - KQ_max[j0/nwarps];
361+
val = h2exp(diff);
362+
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
363+
*((uint *) &val) &= ftz_mask;
355364
KQ_rowsum_add += val;
356365
KQ2[j*(kqs_padded/2) + k] = val;
357366
}
@@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results(
501510
float VKQ_denominator = 0.0f;
502511
#pragma unroll
503512
for (int l = 0; l < parallel_blocks; ++l) {
504-
float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax);
513+
const half diff = __low2half(meta[l]) - kqmax;
514+
float KQ_max_scale = hexp(diff);
515+
const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD));
516+
*((uint *) &KQ_max_scale) &= ftz_mask;
505517

506518
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
507519
VKQ_denominator += KQ_max_scale * __high2float(meta[l]);

0 commit comments

Comments
 (0)