Skip to content

Commit 50d4277

Browse files
committed
refactor mmvq to unify the calculation of nwarps and rows per block between host and device code.
1 parent 1a24c46 commit 50d4277

File tree

2 files changed

+128
-59
lines changed

2 files changed

+128
-59
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
395395

396396
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
397397
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
398-
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
398+
#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
399399
c = __builtin_amdgcn_sdot4(a, b, c, false);
400400
#elif defined(RDNA3)
401401
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
402-
#elif defined(__gfx1010__) || defined(__gfx900__)
402+
#elif defined(RDNA1) || defined(__gfx900__)
403403
int tmp1;
404404
int tmp2;
405405
asm("\n \

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 126 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -47,36 +47,93 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
4747
1;
4848
}
4949

50+
// Device-side table selector: returns 1 when compiling for AMD RDNA2/RDNA3,
// 0 for every other target. Must stay in sync with the host overload
// get_device_table(int cc) so host launch params match the kernel's constants.
static constexpr __device__ int get_device_table()
{
#if !(defined(RDNA2) || defined(RDNA3))
    return 0;
#else
    return 1;
#endif // !(defined(RDNA2) || defined(RDNA3))
}
58+
59+
// Host-side table selector: mirrors the device overload above so that the
// launch configuration computed on the host matches the constants the kernel
// was compiled with. Returns 1 for RDNA2/RDNA3 compute capabilities, else 0.
static __host__ int get_device_table(int cc)
{
    const bool is_rdna2_or_3 = GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
    return is_rdna2_or_3 ? 1 : 0;
}
67+
68+
// Number of warps per thread block for mul_mat_vec_q, shared between the
// host launch-parameter calculation and the device kernel.
//
// BUG FIX: the 2/4 returns were swapped relative to the behavior this
// refactor is meant to preserve. The previous host-side switch used
// nwarps = 4 for ncols_y 1..4 and nwarps = 2 for ncols_y 5..8, and the old
// __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) agreed with it.
static constexpr int calc_nwarps(int ncols_y, int table_id)
{
    if (table_id == 0)
    {
        switch (ncols_y) {
            case 1:
            case 2:
            case 3:
            case 4:
                return 4;
            case 5:
            case 6:
            case 7:
            case 8:
                return 2;
            default:
                // ncols_y > MMVQ_MAX_BATCH_SIZE is rejected by the host
                // wrapper; keep a safe fallback for completeness.
                return 1;
        }
    } else {
        // Table 1 (RDNA2/RDNA3): a single warp per block.
        return 1;
    }
}
90+
91+
// Number of dst rows each thread block computes for mul_mat_vec_q, shared
// between host launch-parameter calculation and the device kernel.
static constexpr int calc_rows_per_block(int ncols_y, int table_id)
{
    // Table 1 (RDNA2/RDNA3): always a single row per block.
    if (table_id != 0) {
        return 1;
    }
    // Table 0: one row for a single y column, two rows for batch sizes 2..8,
    // and a safe fallback of one row for anything larger.
    if (ncols_y == 1) {
        return 1;
    }
    return (ncols_y >= 2 && ncols_y <= 8) ? 2 : 1;
}
112+
50113
template <ggml_type type, int ncols_y>
51-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
52114
// tell the compiler to use as many registers as it wants, see nwarps definition below
53-
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
54-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
115+
__launch_bounds__(calc_nwarps(ncols_y, get_device_table())*ggml_cuda_get_physical_warp_size(), 4)
55116
static __global__ void mul_mat_vec_q(
56117
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
57118
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
58119

59120
constexpr int qk = ggml_cuda_type_traits<type>::qk;
60121
constexpr int qi = ggml_cuda_type_traits<type>::qi;
61122
constexpr int vdr = get_vdr_mmvq(type);
123+
constexpr int table_id = get_device_table();
124+
constexpr int nwarps = calc_nwarps(ncols_y, table_id);
125+
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
126+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
62127

63128
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
64129

65-
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
66-
constexpr int nwarps = 1;
67-
constexpr int rows_per_cuda_block = 1;
68-
#else
69-
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
70-
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
71-
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
72-
73-
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
130+
const int tid = warp_size*threadIdx.y + threadIdx.x;
74131
const int row0 = rows_per_cuda_block*blockIdx.x;
75132
const int blocks_per_row_x = ncols_x / qk;
76133
const int blocks_per_col_y = nrows_y / QK8_1;
77-
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
134+
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
78135

79-
// partial sum for each thread
136+
// partial sum for each thread
80137
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
81138

82139
const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +153,7 @@ static __global__ void mul_mat_vec_q(
96153
}
97154
}
98155

99-
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
156+
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
100157
if (threadIdx.y > 0) {
101158
#pragma unroll
102159
for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +177,7 @@ static __global__ void mul_mat_vec_q(
120177
for (int l = 0; l < nwarps-1; ++l) {
121178
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
122179
}
123-
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
180+
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
124181
}
125182

126183
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,73 +186,85 @@ static __global__ void mul_mat_vec_q(
129186
}
130187
}
131188

189+
// Computes the {grid, block} dimensions for a mul_mat_vec_q launch.
// Each block covers calc_rows_per_block() rows of dst, so the grid size is
// the ceiling division of nrows_x by that value; block dims are
// warp_size x nwarps. Hoists the repeated calc_rows_per_block call into a
// local instead of evaluating it twice.
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, int table_id)
{
    const int     rows_per_block = calc_rows_per_block(ncols_y, table_id);
    const int64_t nblocks        = (nrows_x + rows_per_block - 1) / rows_per_block; // ceil-div so a partial tail of rows is covered
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
    return {block_nums, block_dims};
}
196+
132197
template <ggml_type type>
133198
static void mul_mat_vec_q_cuda(
134199
const void * vx, const void * vy, float * dst,
135200
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
201+
int device;
202+
int warp_size;
136203

137204
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
138205
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
139206

140-
int id = ggml_cuda_get_device();
141-
142-
int64_t nwarps = 1;
143-
int64_t rows_per_cuda_block = 1;
144-
145-
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
146-
switch(ncols_y) {
147-
case 1:
148-
nwarps = 4;
149-
rows_per_cuda_block = 1;
150-
break;
151-
case 2:
152-
case 3:
153-
case 4:
154-
nwarps = 4;
155-
rows_per_cuda_block = 2;
156-
break;
157-
case 5:
158-
case 6:
159-
case 7:
160-
case 8:
161-
nwarps = 2;
162-
rows_per_cuda_block = 2;
163-
break;
164-
default:
165-
GGML_ABORT("fatal error");
166-
break;
167-
}
168-
}
169-
170-
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
171-
const dim3 block_nums(nblocks, 1, 1);
172-
const dim3 block_dims(WARP_SIZE, nwarps, 1);
207+
CUDA_CHECK(cudaGetDevice(&device));
208+
warp_size = ggml_cuda_info().devices[device].warp_size;
209+
int table_id = get_device_table(ggml_cuda_info().devices[device].cc);
173210

174211
switch (ncols_y) {
175212
case 1:
176-
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
213+
{
214+
constexpr int c_ncols_y = 1;
215+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
216+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
177217
break;
218+
}
178219
case 2:
179-
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
220+
{
221+
constexpr int c_ncols_y = 2;
222+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
223+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
180224
break;
225+
}
181226
case 3:
182-
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
227+
{
228+
constexpr int c_ncols_y = 3;
229+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
230+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
183231
break;
232+
}
184233
case 4:
185-
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
234+
{
235+
constexpr int c_ncols_y = 4;
236+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
237+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
186238
break;
239+
}
187240
case 5:
188-
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
241+
{
242+
constexpr int c_ncols_y = 5;
243+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
244+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
189245
break;
246+
}
190247
case 6:
191-
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
248+
{
249+
constexpr int c_ncols_y = 6;
250+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
251+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
192252
break;
253+
}
193254
case 7:
194-
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
255+
{
256+
constexpr int c_ncols_y = 7;
257+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
258+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
195259
break;
260+
}
196261
case 8:
197-
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
262+
{
263+
constexpr int c_ncols_y = 8;
264+
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
265+
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
198266
break;
267+
}
199268
default:
200269
GGML_ABORT("fatal error");
201270
break;

0 commit comments

Comments
 (0)