
Commit 888ffc8

refactor mmvq to unify the calculation of nwarps and rows per block between host and device code.
1 parent 1a24c46 commit 888ffc8
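
The change, in short: the kernel used to pick nwarps and rows_per_cuda_block through preprocessor branches while the host picked them again through a runtime switch; both now call the same constexpr helpers. A minimal sketch of that pattern, with simplified values and illustrative names (example_nwarps, example_kernel and launch_example are not part of this commit):

#include <cuda_runtime.h>

// Shared helper: constexpr and __host__ __device__, so the kernel's
// __launch_bounds__ and the host-side launch configuration evaluate the
// exact same function at compile time.
static constexpr __host__ __device__ int example_nwarps(int ncols_y) {
    return ncols_y <= 4 ? 4 : 2;
}

template <int ncols_y>
__launch_bounds__(example_nwarps(ncols_y) * 32, 1)
static __global__ void example_kernel(float * dst) {
    constexpr int nwarps = example_nwarps(ncols_y); // same value the host used
    if (threadIdx.x == 0) {
        dst[threadIdx.y] = (float) nwarps;
    }
}

static void launch_example(float * dst, cudaStream_t stream) {
    constexpr int ncols_y = 2;
    const dim3 block_dims(32, example_nwarps(ncols_y), 1); // host reuses the helper
    example_kernel<ncols_y><<<1, block_dims, 0, stream>>>(dst);
}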

2 files changed, 143 insertions(+), 59 deletions(-)

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
@@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half

 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
+#elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
     int tmp2;
     asm("\n \
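
For context, ggml_cuda_dp4a is a 4-way int8 dot product with 32-bit accumulate; the hunk above only changes which AMD targets map it to the native __builtin_amdgcn_sdot4 / __builtin_amdgcn_sudot4 builtins versus the inline assembly path. A plain C++ sketch of the arithmetic being computed (dp4a_reference is an illustrative name, not a function in the repository):

#include <cstdint>

// Reference semantics of dp4a: treat each 32-bit int as four signed 8-bit
// lanes, multiply lane-wise, and add the sum of the products to the
// accumulator c.
static int dp4a_reference(int a, int b, int c) {
    const int8_t * a8 = reinterpret_cast<const int8_t *>(&a);
    const int8_t * b8 = reinterpret_cast<const int8_t *>(&b);
    for (int i = 0; i < 4; ++i) {
        c += a8[i] * b8[i];
    }
    return c;
}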

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 141 additions & 57 deletions
@@ -47,36 +47,108 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }

+static constexpr __device__ int get_device_table_id()
+{
+#if defined(RDNA2) || defined(RDNA3)
+    return 2;
+#elif defined(GCN) || defined(CDNA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+static __host__ int get_device_table_id(int cc)
+{
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return 2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return 1;
+    }
+    return 0;
+}
+
+static constexpr int calc_nwarps(int ncols_y, int table_id)
+{
+    if (table_id == 0) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if(table_id == 1) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr int calc_rows_per_block(int ncols_y, int table_id)
+{
+    if (table_id == 0 || table_id == 1) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int qi = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr int table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int tid = warp_size*threadIdx.y + threadIdx.x;
     const int row0 = rows_per_cuda_block*blockIdx.x;
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;

-    // partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

     const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +168,7 @@ static __global__ void mul_mat_vec_q(
         }
     }

-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +192,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }

         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,73 +201,85 @@ static __global__ void mul_mat_vec_q(
     }
 }

+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, int table_id)
+{
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    int device;
+    int warp_size;

     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch(ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+    int table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);

     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;
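
The commit message says the nwarps and rows-per-block calculation is now unified between host and device; for table_id 0 (the default path, e.g. NVIDIA) the new tables also reproduce exactly what the removed host-side switch selected. A few illustrative compile-time checks that could sit in mmvq.cu below the new helpers (not part of this commit):

// table_id 0: ncols_y 1..4 -> 4 warps, 5..8 -> 2 warps; a single row per block
// only for ncols_y == 1, matching the switch removed from mul_mat_vec_q_cuda.
static_assert(calc_nwarps(1, 0) == 4 && calc_rows_per_block(1, 0) == 1, "ncols_y == 1");
static_assert(calc_nwarps(4, 0) == 4 && calc_rows_per_block(4, 0) == 2, "ncols_y == 4");
static_assert(calc_nwarps(8, 0) == 2 && calc_rows_per_block(8, 0) == 2, "ncols_y == 8");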
