@@ -47,36 +47,108 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }
 
+static constexpr __device__ int get_device_table_id()
+{
+#if defined(RDNA2) || defined(RDNA3)
+    return 2;
+#elif defined(GCN) || defined(CDNA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+static __host__ int get_device_table_id(int cc)
+{
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return 2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return 1;
+    }
+    return 0;
+}
+
+static constexpr int calc_nwarps(int ncols_y, int table_id)
+{
+    if (table_id == 0) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == 1) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr int calc_rows_per_block(int ncols_y, int table_id)
+{
+    if (table_id == 0 || table_id == 1) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr int table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps              = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int tid = warp_size*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
     const     int blocks_per_row_x = ncols_x / qk;
     const     int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-// partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +168,7 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +192,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
 
         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,73 +201,85 @@ static __global__ void mul_mat_vec_q(
         }
     }
 }
 
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, int table_id)
+{
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    int device;
+    int warp_size;
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch (ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+    int table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;
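
To see what the new tables amount to in practice, the standalone C++ sketch below (not part of the patch) condenses the `calc_nwarps` / `calc_rows_per_block` logic from the diff and prints the grid/block shape that `calc_launch_params` would hand to the kernel launch for each batch size. The table ids follow the diff (0 = NVIDIA/default, 1 = GCN/CDNA, 2 = RDNA2/RDNA3); the 32/64 warp sizes are the usual wavefront widths for those targets and are assumed here purely for illustration.

```cpp
// Standalone sketch: mirrors the nwarps / rows_per_cuda_block tables from the
// diff above and prints the resulting launch geometry. Illustrative only; the
// real helpers live in ggml-cuda/mmvq.cu and return dim3 pairs for <<<...>>>.
#include <cstdio>

static constexpr int calc_nwarps(int ncols_y, int table_id) {
    if (table_id == 0) {                      // NVIDIA and other defaults
        return ncols_y <= 4 ? 4 : (ncols_y <= 8 ? 2 : 1);
    }
    if (table_id == 1) {                      // GCN / CDNA
        return ncols_y <= 4 ? 2 : 1;
    }
    return 1;                                 // RDNA2 / RDNA3
}

static constexpr int calc_rows_per_block(int ncols_y, int table_id) {
    if ((table_id == 0 || table_id == 1) && ncols_y >= 2 && ncols_y <= 8) {
        return 2;
    }
    return 1;
}

int main() {
    // Warp sizes assumed per target for this illustration.
    const struct { const char * name; int table_id; int warp_size; } devices[] = {
        { "default (table 0)",  0, 32 },
        { "GCN/CDNA (table 1)", 1, 64 },
        { "RDNA2/3 (table 2)",  2, 32 },
    };
    const int nrows_x = 4096;                 // example weight-matrix row count

    for (const auto & dev : devices) {
        for (int ncols_y = 1; ncols_y <= 8; ++ncols_y) {
            const int rows_per_block = calc_rows_per_block(ncols_y, dev.table_id);
            const int nwarps         = calc_nwarps(ncols_y, dev.table_id);
            const int nblocks        = (nrows_x + rows_per_block - 1) / rows_per_block;
            printf("%-18s ncols_y=%d -> grid=(%d,1,1) block=(%d,%d,1)\n",
                   dev.name, ncols_y, nblocks, dev.warp_size, nwarps);
        }
    }
    return 0;
}
```

For example, with `ncols_y = 8` the default table launches 2 warps covering 2 rows per block, while RDNA2/RDNA3 stays at a single warp and a single row per block, matching the `#if defined(RDNA2) || defined(RDNA3)` special case that the patch removes from the kernel body.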