@@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16(
335
335
for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
336
336
const int j = j0 + threadIdx .y ;
337
337
338
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2 *WARP_SIZE)];
339
+ #pragma unroll
340
+ for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
341
+ const int k = k0 + threadIdx .x ;
342
+
343
+ KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2 ) + k];
344
+ }
345
+
338
346
half2 KQ_max_new = KQ_max[j0/nwarps];
339
347
#pragma unroll
340
348
for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
341
349
const int k = k0 + threadIdx .x ;
342
- half2 val = KQ2[j*(kqs_padded/2 ) + k];
343
- val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
344
- KQ_max_new = __hmax2 (KQ_max_new, val);
345
- KQ2[j*(kqs_padded/2 ) + k] = val;
350
+
351
+ KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
352
+ KQ_max_new = __hmax2 (KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
346
353
}
347
354
KQ_max_new = __half2half2 (warp_reduce_max (__hmax (__low2half (KQ_max_new), __high2half (KQ_max_new))));
348
355
const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
@@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16(
356
363
for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
357
364
const int k = k0 + threadIdx .x ;
358
365
359
- half2 val = KQ2[j*(kqs_padded/2 ) + k];
360
- const half2 diff = val - KQ_max[j0/nwarps];
361
- val = h2exp (diff);
366
+ const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps];
367
+ KQ2_tmp[k0/WARP_SIZE] = h2exp (diff);
362
368
const uint ftz_mask = __hgt2_mask (diff, make_half2 (SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
363
- *((uint *) &val ) &= ftz_mask;
364
- KQ_rowsum_add += val ;
365
- KQ2[j*(kqs_padded/2 ) + k] = val ;
369
+ *((uint *) &KQ2_tmp[k0/WARP_SIZE] ) &= ftz_mask;
370
+ KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE] ;
371
+ KQ2[j*(kqs_padded/2 ) + k] = KQ2_tmp[k0/WARP_SIZE] ;
366
372
}
367
373
KQ_rowsum_add = warp_reduce_sum (KQ_rowsum_add);
368
374
0 commit comments