@@ -47,36 +47,93 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }
 
+static constexpr __device__ int get_device_table()
+{
+#if defined(RDNA2) || defined(RDNA3)
+    return 1;
+#else
+    return 0;
+#endif // defined(RDNA2) || defined(RDNA3)
+}
+
+static __host__ int get_device_table(int cc)
+{
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return 1;
+    }
+
+    return 0;
+}
+
+static constexpr int calc_nwarps(int ncols_y, int table_id)
+{
+    if (table_id == 0) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else {
+        return 1;
+    }
+}
+
+static constexpr int calc_rows_per_block(int ncols_y, int table_id)
+{
+    if (table_id == 0) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else {
+        return 1;
+    }
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table())*ggml_cuda_get_physical_warp_size(), 4)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr int table_id = get_device_table();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps              = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int tid = warp_size*threadIdx.y + threadIdx.x;
     const int row0 = rows_per_cuda_block*blockIdx.x;
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-// partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q8_1 * y = (const block_q8_1 *) vy;
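As a quick check of what the new constexpr helpers select, here is a small host-only sketch; the copies of calc_nwarps/calc_rows_per_block and the main() driver are illustration only and not part of the patch. It prints, for each ncols_y, the configuration chosen by the generic table (table_id 0) and by the RDNA2/RDNA3 table (table_id 1):

// Illustration only: host-side copies of the helpers above, printing the
// configuration selected for each ncols_y. Not part of the patch.
#include <cstdio>

static constexpr int calc_nwarps(int ncols_y, int table_id) {
    if (table_id == 0) {
        switch (ncols_y) {
            case 1: case 2: case 3: case 4: return 4;
            case 5: case 6: case 7: case 8: return 2;
            default: return 1;
        }
    }
    return 1; // RDNA2/RDNA3: a single warp per block
}

static constexpr int calc_rows_per_block(int ncols_y, int table_id) {
    if (table_id == 0) {
        switch (ncols_y) {
            case 1: return 1;
            case 2: case 3: case 4: case 5: case 6: case 7: case 8: return 2;
            default: return 1;
        }
    }
    return 1; // RDNA2/RDNA3: a single row per block
}

int main() {
    for (int table_id = 0; table_id <= 1; ++table_id) {
        std::printf("table_id=%d (%s)\n", table_id, table_id == 0 ? "generic" : "RDNA2/RDNA3");
        for (int ncols_y = 1; ncols_y <= 8; ++ncols_y) {
            std::printf("  ncols_y=%d -> nwarps=%d, rows_per_cuda_block=%d\n",
                        ncols_y, calc_nwarps(ncols_y, table_id), calc_rows_per_block(ncols_y, table_id));
        }
    }
    return 0;
}

On the generic table this reproduces the parameter table removed from mul_mat_vec_q_cuda further down: 4 warps for ncols_y 1..4, 2 warps for 5..8, and 2 rows per block for every ncols_y except 1.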
@@ -96,7 +153,7 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +177,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
 
         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
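The reduction now has to match the physical warp width, since each block is laid out as warp_size x nwarps threads. Below is a minimal sketch of such a width-parameterized reduction; it is an illustrative re-implementation (the name warp_reduce_sum_sketch is made up), not ggml's warp_reduce_sum helper, and it is written against NVIDIA's 32-lane warps; under HIP the same pattern is applied across a 64-lane wave on GCN/CDNA.

// Sketch only: butterfly sum across `width` lanes using warp shuffles.
template <int width>
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        // fold in the value held by the lane `offset` positions away
        x += __shfl_xor_sync(0xffffffffu, x, offset, width);
    }
    return x; // every participating lane ends up holding the full sum
}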
@@ -129,73 +186,85 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, int table_id)
+{
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    int device;
+    int warp_size;
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch (ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+    int table_id = get_device_table(ggml_cuda_info().devices[device].cc);
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;
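For a concrete feel of the launch parameters, the arithmetic below uses made-up sizes (nrows_x = 4096, ncols_y = 4) and a reported warp size of 32; in the real code these come from the tensor shapes and from ggml_cuda_info(). On the generic table this yields 2048 blocks of 32x4 = 128 threads; on the RDNA2/RDNA3 table, 4096 blocks of a single 32-thread warp.

// Illustration only: the grid/block arithmetic of calc_launch_params on made-up sizes.
#include <cstdio>
#include <cstdint>

static int64_t nblocks_for(int nrows_x, int rows_per_block) {
    return (nrows_x + rows_per_block - 1) / rows_per_block; // ceiling division, as above
}

int main() {
    const int nrows_x = 4096; // hypothetical row count

    // generic table, ncols_y = 4: nwarps = 4, rows_per_cuda_block = 2
    std::printf("generic: grid=(%lld,1,1), block=(32,4,1)\n", (long long) nblocks_for(nrows_x, 2));

    // RDNA2/RDNA3 table, ncols_y = 4: nwarps = 1, rows_per_cuda_block = 1
    std::printf("RDNA2/3: grid=(%lld,1,1), block=(32,1,1)\n", (long long) nblocks_for(nrows_x, 1));
    return 0;
}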