Commit ac6ae5d

Fix flash-attn for AMD
1 parent 871fcb6 commit ac6ae5d

2 files changed: +101, -90 lines

ggml-cuda/common.cuh

Lines changed: 74 additions & 74 deletions
@@ -232,80 +232,6 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
-[[noreturn]]
-static __device__ void no_device_code(
-    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
-           file_name, line, function_name, arch);
-    GGML_UNUSED(arch_list);
-#else
-    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
-           file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    __trap();
-
-    GGML_UNUSED(no_device_code); // suppress unused function warning
-}
-
-#ifdef __CUDA_ARCH__
-#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
-#else
-#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
-#endif // __CUDA_ARCH__
-
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-static __device__ __forceinline__ float warp_reduce_max(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-#else
-    GGML_UNUSED(x);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-}
-
 #if CUDART_VERSION < 12000
 static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) {
     const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b));
@@ -397,6 +323,80 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
 
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+[[noreturn]]
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    GGML_UNUSED(arch_list);
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    __trap();
+
+    GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    GGML_UNUSED(a);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
 #define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
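The commit message does not spell out why these definitions move, but the hunk context shows where they land: directly after the HIP compatibility block that ends with the __dp4a wrapper and "#endif // defined(GGML_USE_HIPBLAS)". A plausible reading is that the warp-reduction helpers must come after those shims so that AMD builds see CUDA-style intrinsics such as __shfl_xor_sync before they are used. A minimal sketch of that ordering issue follows; the shim macro shown is a hypothetical stand-in, not code from this commit.

// Sketch only. Hypothetical stand-in for a CUDA-style wrapper provided on HIP builds.
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#define __shfl_xor_sync(mask, var, lane_mask, width) __shfl_xor(var, lane_mask, width)
#endif

// Because the shim is a macro expanded textually at the point of use, a helper that
// calls __shfl_xor_sync only picks it up on HIP if the helper is defined after the shim:
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32); // butterfly reduction across the warp
    }
    return x;
}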

ggml-cuda/fattn.cu

Lines changed: 27 additions & 16 deletions
@@ -2,7 +2,16 @@
 #include "fattn.cuh"
 
 #include <cstdint>
+
+#if FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#else
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // FP16_MMA_AVAILABLE
 
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
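With the alias above in place, the kernel can be written once against "wmma::" and compile against either nvcuda::wmma or rocwmma, which is why the remaining hunks simply drop the hard-coded nvcuda::wmma:: qualifier. Below is a minimal sketch of that pattern, not code from the commit: the function name is made up, it assumes the includes and alias from the hunk above, and it uses the 16x16x16 half-precision tile shape supported by both libraries.

// Warp-cooperative sketch: all 32 threads of a warp must call this together.
static __device__ void tiny_wmma_sketch(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;

    wmma::fill_fragment(c, 0.0f);      // zero the 16x16 accumulator tile
    wmma::load_matrix_sync(a, A, 16);  // load a 16x16 half tile of A (leading dimension 16)
    wmma::load_matrix_sync(b, B, 16);  // load a 16x16 half tile of B (leading dimension 16)
    wmma::mma_sync(c, a, b, c);        // c += a * b on the tensor/matrix cores
    wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major); // write the float result tile
}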
@@ -228,11 +237,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -316,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -330,20 +339,20 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], 0.0f);
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }
 
@@ -449,7 +458,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -461,18 +470,18 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
             }
 
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -484,10 +493,10 @@ static __global__ void flash_attn_ext_f16(
     for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(
+            wmma::store_matrix_sync(
                 KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                D_padded, nvcuda::wmma::mem_col_major);
+                D_padded, wmma::mem_col_major);
         }
     }
 
@@ -860,6 +869,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // 32x8 tensor cores are not available on AMD.
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         constexpr int nwarps = 4;
@@ -882,6 +892,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
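The last two hunks wrap the cols_per_block == 8 launch in a preprocessor guard; per the in-diff comment, 32x8 tensor cores are not available on AMD, so on HIP builds small batches simply fall through to the Q->ne[1] <= 32 branch. A simplified sketch of the resulting host-side dispatch follows; it is not the commit's code, and launch_cols<> is a hypothetical helper standing in for the real kernel launches.

// Sketch of the dispatch shape after this change (hypothetical helper, launches abbreviated).
static void flash_attn_dispatch_sketch(const ggml_tensor * Q) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
        // launch_cols<8>(Q);   // 32x8 fragments: NVIDIA only
        return;
    }
#endif
    if (Q->ne[1] <= 32) {
        // launch_cols<16>(Q);  // 16x16 fragments: available on both backends
        return;
    }
    // larger batches use a wider cols_per_block (see the surrounding file)
}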
