
Commit b1ca685

hjc4869, JohannesGaessler, and bjj authored and committed
HIP: implement FlashAttention via rocWMMA for CDNA and RDNA3+ (ggml-org#12032)
Adds GGML_HIP_ROCWMMA_FATTN and a rocwmma header check.
Adds rocWMMA support to fattn-wmma-f16.

---
Signed-off-by: Carl Klemm <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Ben Jackson <[email protected]>
1 parent 4789f94 commit b1ca685
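
For context: rocWMMA exposes a fragment-based matrix-multiply API that closely mirrors CUDA's nvcuda::wmma, which is what allows the existing fattn-wmma-f16 kernel to be reused on AMD matrix cores. The snippet below is only an illustrative sketch of that idea; the actual include/namespace wiring in fattn-wmma-f16 (not shown in this view) may differ.

    // Illustrative sketch only: one WMMA kernel source targeting both backends
    // by aliasing the namespace at compile time.
    #if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
    #include <rocwmma/rocwmma.hpp>
    namespace wmma = rocwmma;      // AMD matrix cores: CDNA MFMA, RDNA3+ WMMA
    #else
    #include <mma.h>
    namespace wmma = nvcuda::wmma; // NVIDIA tensor cores, Volta and newer
    #endif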

File tree: 6 files changed (+145, -95 lines)

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balan
 option(GGML_HIP "ggml: use HIP" OFF)
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
+option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
 option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
 option(GGML_VULKAN "ggml: use Vulkan" OFF)
 option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
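
The commit message also mentions a rocwmma header check; that check is done at configure time (presumably in the HIP backend's CMake, which is not part of the hunks shown here). As a purely hypothetical illustration, the same guard expressed in C++ would look like this:

    // Hypothetical sketch only: fail early if the option is on but the
    // rocWMMA headers are not installed. The real check happens in CMake.
    #if defined(GGML_HIP_ROCWMMA_FATTN) && !__has_include(<rocwmma/rocwmma.hpp>)
    #error "GGML_HIP_ROCWMMA_FATTN is enabled but rocwmma/rocwmma.hpp was not found"
    #endif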

ggml/src/ggml-cuda/common.cuh

Lines changed: 13 additions & 2 deletions
@@ -62,6 +62,7 @@
 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA

+#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
@@ -196,6 +197,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) {

 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
 }

 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
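
With these changes, fp16_mma_available(cc) is the gate for the WMMA FlashAttention path: it now also returns true for CDNA and RDNA3+ devices, but only in builds that define GGML_HIP_ROCWMMA_FATTN. A minimal sketch of how dispatch code might consult it; the helper below is hypothetical and the .cc field name is assumed, while ggml_cuda_info().devices[...] mirrors the usage visible later in this diff:

    // Minimal sketch, not the actual ggml dispatch code: decide whether the
    // WMMA-based FlashAttention kernel is eligible on the current device.
    static bool wmma_fattn_eligible_sketch(int device) {
        const int cc = ggml_cuda_info().devices[device].cc; // compute capability (assumed field)
        // In AMD builds without GGML_HIP_ROCWMMA_FATTN this compiles to `return false`,
        // so the WMMA path is never selected there.
        return fp16_mma_available(cc);
    }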

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 38 additions & 28 deletions
@@ -57,35 +57,36 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
         const int iqs4 = k_KQ % QI4_0;
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

@@ -97,37 +98,38 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
         const int iqs4 = k_KQ % QI4_1;
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid4d8 + m4s8scaled);
         }
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -161,22 +164,22 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -208,24 +212,24 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }

-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }

     return sum;
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

@@ -282,11 +288,11 @@
         half2 sum2 = make_half2(0.0f, 0.0f);

 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
             const int k_KQ = k_KQ_0 + threadIdx.x;

             const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }

         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const half2 K_ik = K_h2[k_KQ];
-        sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }

     return sum;
@@ -698,6 +704,8 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
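
The recurring change in this file is replacing the fixed WARP_SIZE constant (32) with the device's physical warp size: CDNA wavefronts are 64 lanes wide, whereas NVIDIA warps (and RDNA waves under HIP's default wave32 mode) are 32. Device code queries it via the constexpr helper ggml_cuda_get_physical_warp_size(); host code reads ggml_cuda_info().devices[...].warp_size, and launch_fattn now asserts that the block width is a multiple of the warp size and that AMD blocks do not exceed four waves. A conceptual sketch of what such a helper boils down to (the real helper lives in common.cuh and its exact preprocessor conditions may differ):

    // Conceptual sketch only; not the actual ggml_cuda_get_physical_warp_size().
    static constexpr __device__ int physical_warp_size_sketch() {
    #if defined(GGML_USE_HIP) && defined(CDNA) // wave64 hardware (illustrative condition)
        return 64;
    #else
        return 32;                             // NVIDIA warps and RDNA wave32
    #endif
    }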
