Skip to content

Commit ab0e5de

Browse files
daniandtheweb authored and Neo Zhang committed
Define and optimize RDNA1 (ggml-org#8085)
1 parent 80ffd6e commit ab0e5de

File tree

2 files changed: 11 additions and 3 deletions.

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,10 @@ typedef float2 dfloat2;
227227
#define RDNA2
228228
#endif
229229

230+
#if defined(__gfx1010__) || defined(__gfx1012__)
231+
#define RDNA1
232+
#endif
233+
230234
#ifndef __has_builtin
231235
#define __has_builtin(x) 0
232236
#endif

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -60,12 +60,16 @@ static constexpr __device__ int get_mmq_x_max_device() {
6060
}
6161

6262
static constexpr int get_mmq_y_host(const int cc) {
63-
return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64;
63+
return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64);
6464
}
6565

6666
static constexpr __device__ int get_mmq_y_device() {
6767
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
68+
#if defined(RDNA1)
69+
return 64;
70+
#else
6871
return 128;
72+
#endif // defined RDNA1
6973
#else
7074
#if __CUDA_ARCH__ >= CC_VOLTA
7175
return 128;
@@ -2259,9 +2263,9 @@ static __device__ void mul_mat_q_process_tile(
22592263

22602264
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
22612265
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2262-
#if defined(RDNA3) || defined(RDNA2)
2266+
#if defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
22632267
__launch_bounds__(WARP_SIZE*nwarps, 2)
2264-
#endif // defined(RDNA3) || defined(RDNA2)
2268+
#endif // defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
22652269
#else
22662270
#if __CUDA_ARCH__ >= CC_VOLTA
22672271
__launch_bounds__(WARP_SIZE*nwarps, 1)

0 commit comments

Comments
 (0)