@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
    const int D16 = D/16;
    const int Q16 = Q/16;
    const int NW  = WARP_SIZE;
-   const int SH  = (C + Q);   // shared memory per simdgroup in (half)
+   const int SH  = (C + 2*Q); // shared memory per simdgroup in (half)

    const int T  = D + num_warps*SH; // shared memory size per query in (half)
    const int T2 = T/2;              // shared memory size per query in (half2)
@@ -6526,11 +6526,16 @@ static __global__ void flash_attn_ext_f16(
        }
    }

-   const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
    // pointer to the mask
    const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

+   // prepare diagonal scale matrix
+   half16x16_b mscale;
+   for (int i = 0; i < 16; ++i) {
+       ss[i*T + i] = __float2half(scale);
+   }
+   nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
    // loop over the KV cache
    // each simdgroup handles blocks of Q rows and C columns
    for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
@@ -6555,10 +6560,15 @@ static __global__ void flash_attn_ext_f16(

                // mqk = mqk*scale + mask
                for (int64_t j = 0; j < Q16; ++j) {
-                   for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                       // TODO: process mask
-                       mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                   }
+                   half16x16_a   mqka;
+                   half16x16_acc mm;
+
+                   // convert accumulator to matrix_a
+                   nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                   nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                   nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                   nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                }
            }
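
The hunk above replaces the element-wise scaling loop with a single tensor-core operation: because A*diag(s) scales each column of A, loading the mask tile into the accumulator and multiplying mqk by the diagonal mscale fragment computes mqk*scale + mask in one mma_sync. Below is a minimal standalone sketch of the same trick on a single 16x16 tile; it is not part of the commit, and the kernel name, launch configuration and sm_70+ requirement are assumptions made for illustration.

    #include <cuda_fp16.h>
    #include <mma.h>

    using namespace nvcuda;

    // out = scale*A + M for one 16x16 half tile, computed as A*diag(scale) + M
    // with a single wmma::mma_sync (mirrors the mqk*scale + mask update above)
    __global__ void scale_add_mask_16x16(const half * A, const half * M, half * out, float scale) {
        __shared__ __align__(32) half sdiag[16*16];

        // build diag(scale) in shared memory (one thread is enough for a sketch)
        if (threadIdx.x == 0) {
            for (int i = 0; i < 16*16; ++i) {
                sdiag[i] = __float2half(0.0f);
            }
            for (int i = 0; i < 16; ++i) {
                sdiag[16*i + i] = __float2half(scale);
            }
        }
        __syncthreads();

        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> d;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  acc;

        wmma::load_matrix_sync(a, A, 16);
        wmma::load_matrix_sync(d, sdiag, 16);
        wmma::load_matrix_sync(acc, M, 16, wmma::mem_row_major); // start the accumulator at the mask

        wmma::mma_sync(acc, a, d, acc); // acc = A*diag(scale) + M

        wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);
    }

Launched with a single warp, e.g. scale_add_mask_16x16<<<1, 32>>>(dA, dM, dOut, 0.125f), all 32 threads cooperate in the wmma calls.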
@@ -6631,18 +6641,19 @@ static __global__ void flash_attn_ext_f16(

                // O = diag(ms)*O
                for (int64_t j = 0; j < Q16; ++j) {
-                   // half16x16_a mm;
-                   // half16x16_b zro;
+                   half16x16_a mm;
+                   half16x16_b lob;

-                   // nvcuda::wmma::fill_fragment(zro, 0.0);
-                   // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                   nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                    for (int64_t i = 0; i < D16; ++i) {
-                       // nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                       for (uint32_t k = 0; k < 16*16; k++) {
-                           half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                           lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                       }
+                       // convert accumulator to matrix_b
+                       // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
+                       nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                       nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+
+                       nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                       nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                    }
                }

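A pattern that recurs in this commit (and that the TODO above wants to eliminate) is converting a wmma accumulator into an mma_sync operand: accumulator fragments cannot be passed as matrix_a/matrix_b directly, so they are spilled to shared memory and reloaded with the target layout. A minimal sketch of that round-trip follows; the helper name and the scratch/ld parameters are assumptions for illustration, with scratch expected to point at a spare 16x16 half tile in shared memory.

    #include <cuda_fp16.h>
    #include <mma.h>

    using namespace nvcuda;

    // reinterpret a 16x16 half accumulator as a matrix_b operand by round-tripping
    // it through shared memory; must be called by all 32 threads of the warp
    __device__ void acc_to_matrix_b(
            const wmma::fragment<wmma::accumulator, 16, 16, 16, half> & acc,
            wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> & b,
            half * scratch, int ld) {
        wmma::store_matrix_sync(scratch, acc, ld, wmma::mem_row_major); // spill the accumulator tile
        wmma::load_matrix_sync(b, scratch, ld);                         // reload it in matrix_b layout
    }

This conversion is also why SH grows from (C + Q) to (C + 2*Q) in the first hunk: the extra QxQ region per simdgroup is the scratch tile used here.
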
@@ -6732,10 +6743,11 @@ static __global__ void flash_attn_ext_f16(
                    nvcuda::wmma::fill_fragment(t2, 0.0);
                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                    nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                   // store temporally 'lo' data
-                   nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                   // load 'lo' data into t
-                   nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+
+                   // convert accumulator to matrix_b
+                   nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                   nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
                    nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                }
            }
@@ -10897,8 +10909,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
    GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-   GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-               "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+   GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+               "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

    ggml_cuda_set_device(g_main_device);
    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -10914,13 +10926,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

    const int nqpb = 16; // queries per block
    const int ncpw = 32; // cache values per warp (does not work for other values)
-   // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-   const int nwarps = 1;
+
+   const int nwarps_max = 8; // TODO: we don't want to launch too many warps. how many is too many?
+   const int nwarps     = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;

    dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
    dim3 block_dim(32, nwarps, 1);

-   int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+   // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
+   // try to avoid this
+   const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+
    switch (Q->ne[0])
    {
        case 16:
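
As a back-of-the-envelope check of the updated shared-memory size (the concrete numbers are illustrative, not from the commit): for head size Q->ne[0] = 128 with nqpb = 16, ncpw = 32 and nwarps = 4, shmem = 16*(128 + 4*(32 + 2*16))*(sizeof(float)/2) = 16*384*2 = 12288 bytes per block, i.e. the extra nqpb term mentioned in the TODO costs 16*4*16*2 = 2048 bytes compared to the previous (ncpw + nqpb) layout.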