 #include "fattn.cuh"

 #include <cstdint>
+
+#if FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#else
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // FP16_MMA_AVAILABLE

 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
@@ -228,11 +237,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>              frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                  frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -316,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }

@@ -330,20 +339,20 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], 0.0f);
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }

@@ -449,7 +458,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -461,18 +470,18 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
             }

 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -484,10 +493,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }

@@ -860,6 +869,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }

+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // 32x8 tensor cores are not available on AMD.
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         constexpr int nwarps         = 4;
@@ -882,6 +892,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
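
Note: aside from guarding the 32x8 (ncols == 8) path off on AMD, the patch leaves the kernel logic untouched; it only routes every tensor-core call through a single `wmma` namespace alias that is resolved at compile time to `nvcuda::wmma` (CUDA) or `rocwmma` (HIP on AMD). Below is a minimal standalone sketch of that single-source pattern, not part of the patch: the kernel name `tile_gemm_16x16x16`, the fp16 headers, and the one-warp launch are illustrative assumptions, while the `wmma::` fragment types and calls are the same ones `flash_attn_ext_f16` uses for its KQ and VKQ tiles.

// Hedged sketch (assumptions noted above): one warp computes a 16x16 tile of
// C = A*B with A row-major, B column-major, and an fp32 accumulator, written
// only against the `wmma` alias so the same source builds for both backends.
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#else
#include <cuda_fp16.h>
#include <mma.h>
namespace wmma = nvcuda::wmma;
#endif

static __global__ void tile_gemm_16x16x16(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;

    wmma::fill_fragment(c, 0.0f);            // zero the accumulator tile
    wmma::load_matrix_sync(a, A, 16);        // leading dimension = 16
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);              // c += a*b on the tensor cores
    wmma::store_matrix_sync(C, c, 16, wmma::mem_col_major);
}
// Launched with a single warp, e.g. tile_gemm_16x16x16<<<1, 32>>>(A, B, C) on NVIDIA
// (CDNA wavefronts are 64 wide, so the HIP launch would use 64 threads).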