Skip to content

Commit a5b0e2d

Browse files
store temp KQ in registers
1 parent ef9e159 commit a5b0e2d

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

ggml-cuda/fattn.cu

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16(
335335
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
336336
const int j = j0 + threadIdx.y;
337337

338+
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
339+
#pragma unroll
340+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
341+
const int k = k0 + threadIdx.x;
342+
343+
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
344+
}
345+
338346
half2 KQ_max_new = KQ_max[j0/nwarps];
339347
#pragma unroll
340348
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
341349
const int k = k0 + threadIdx.x;
342-
half2 val = KQ2[j*(kqs_padded/2) + k];
343-
val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
344-
KQ_max_new = __hmax2(KQ_max_new, val);
345-
KQ2[j*(kqs_padded/2) + k] = val;
350+
351+
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
352+
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
346353
}
347354
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
348355
const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
@@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16(
356363
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
357364
const int k = k0 + threadIdx.x;
358365

359-
half2 val = KQ2[j*(kqs_padded/2) + k];
360-
const half2 diff = val - KQ_max[j0/nwarps];
361-
val = h2exp(diff);
366+
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps];
367+
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
362368
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
363-
*((uint *) &val) &= ftz_mask;
364-
KQ_rowsum_add += val;
365-
KQ2[j*(kqs_padded/2) + k] = val;
369+
*((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
370+
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
371+
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
366372
}
367373
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
368374

0 commit comments

Comments
 (0)