@@ -5310,45 +5310,64 @@ template <bool need_check> static __global__ void
5310
5310
#endif // __CUDA_ARCH__ >= CC_VOLTA
5311
5311
}
5312
5312
5313
- template <int ncols_y_template, int qk, int qi, typename block_q_t , int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5313
+ #define MMVQ_NWARPS_NVIDIA 4 // warps (blockDim.y) per thread block when launching mul_mat_vec_q on non-AMD devices; the launcher's switch only instantiates nwarps values 1 and 4
5314
+ #define MMVQ_NWARPS_AMD 1 // warps (blockDim.y) per thread block when launching mul_mat_vec_q on devices with cc >= CC_OFFSET_AMD
5315
+
5316
+ template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t , int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5317
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
5318
+ __launch_bounds__ (nwarps*WARP_SIZE, 1 ) // at most nwarps*WARP_SIZE threads per block and a minimum of 1 resident block per SM, so the compiler is free to use as many registers as it wants
5319
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
5314
5320
static __global__ void mul_mat_vec_q (
5315
5321
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
5316
5322
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
5317
5323
5318
5324
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
5319
5325
5320
- const int row = blockIdx .x *blockDim .y + threadIdx .y ;
5321
-
5322
- if (row >= nrows_x) {
5323
- return ;
5324
- }
5326
+ const int tid = WARP_SIZE*threadIdx .y + threadIdx .x ;
5327
+ const int row = blockIdx .x ;
5325
5328
5326
5329
const int blocks_per_row_x = ncols_x / qk;
5327
5330
const int blocks_per_col_y = nrows_y / QK8_1;
5328
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
5331
+ const int blocks_per_iter = vdr * nwarps* WARP_SIZE / qi;
5329
5332
5330
5333
// partial sum for each thread
5331
5334
float tmp[ncols_y_template != 0 ? ncols_y_template : 8 ] = {0 .0f };
5332
5335
5333
5336
const block_q_t * x = (const block_q_t *) vx;
5334
5337
const block_q8_1 * y = (const block_q8_1 *) vy;
5335
5338
5336
- for (int i = threadIdx . x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp ) {
5339
+ for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter ) {
5337
5340
const int ibx = row*blocks_per_row_x + i; // x block index
5338
5341
5339
5342
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
5340
5343
5341
- const int iqs = vdr * (threadIdx . x % (qi/vdr)); // x block quant index when casting the quants to int
5344
+ const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
5342
5345
5343
5346
#pragma unroll
5344
5347
for (int j = 0 ; j < ncols_y; ++j) {
5345
5348
tmp[j] += vec_dot_q_cuda (&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
5346
5349
}
5347
5350
}
5348
5351
5352
+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1 ][ncols_y_template != 0 ? ncols_y_template : 8 ][WARP_SIZE];
5353
+ if (threadIdx .y > 0 ) {
5354
+ #pragma unroll
5355
+ for (int j = 0 ; j < ncols_y; ++j) {
5356
+ tmp_shared[threadIdx .y -1 ][j][threadIdx .x ] = tmp[j];
5357
+ }
5358
+ }
5359
+ __syncthreads ();
5360
+ if (threadIdx .y > 0 ) {
5361
+ return ;
5362
+ }
5363
+
5349
5364
// sum up partial sums and write back result
5350
5365
#pragma unroll
5351
5366
for (int j = 0 ; j < ncols_y; ++j) {
5367
+ #pragma unroll
5368
+ for (int i = 0 ; i < nwarps-1 ; ++i) {
5369
+ tmp[j] += tmp_shared[i][j][threadIdx .x ];
5370
+ }
5352
5371
tmp[j] = warp_reduce_sum (tmp[j]);
5353
5372
5354
5373
if (threadIdx .x == 0 ) {
@@ -6833,42 +6852,49 @@ static void mul_mat_vec_q_cuda(
6833
6852
GGML_ASSERT (ncols_x % qk == 0 );
6834
6853
GGML_ASSERT (ncols_y <= 4 );
6835
6854
6836
- const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
6837
- const dim3 block_nums (block_num_y, 1 , 1 );
6838
- const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
6839
- switch (ncols_y) {
6840
- case 1 :
6841
- mul_mat_vec_q<1 , qk, qi, block_q_t , vdr, vec_dot>
6855
+ int id;
6856
+ CUDA_CHECK (cudaGetDevice (&id));
6857
+
6858
+ const int nwarps = g_device_caps[id].cc >= CC_OFFSET_AMD ? MMVQ_NWARPS_AMD : MMVQ_NWARPS_NVIDIA;
6859
+
6860
+ const dim3 block_nums (nrows_x, 1 , 1 );
6861
+ const dim3 block_dims (WARP_SIZE, nwarps, 1 );
6862
+
6863
+ const int32_t config = ncols_y | (nwarps << 16 );
6864
+
6865
+ switch (config) {
6866
+ case 0x00010001 :
6867
+ mul_mat_vec_q<1 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6868
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6869
+ break ;
6870
+ case 0x00010002 :
6871
+ mul_mat_vec_q<1 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6872
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6873
+ break ;
6874
+ case 0x00010003 :
6875
+ mul_mat_vec_q<1 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6876
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6877
+ break ;
6878
+ case 0x00010004 :
6879
+ mul_mat_vec_q<1 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6880
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6881
+ break ;
6882
+ case 0x00040001 :
6883
+ mul_mat_vec_q<4 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6842
6884
<<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6843
6885
break ;
6844
- case 2 :
6845
- mul_mat_vec_q<2 , qk, qi, block_q_t , vdr, vec_dot>
6886
+ case 0x00040002 :
6887
+ mul_mat_vec_q<4 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6846
6888
<<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6847
6889
break ;
6848
- case 3 :
6849
- mul_mat_vec_q<3 , qk, qi, block_q_t , vdr, vec_dot>
6890
+ case 0x00040003 :
6891
+ mul_mat_vec_q<4 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6850
6892
<<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6851
6893
break ;
6852
- case 4 :
6853
- mul_mat_vec_q<4 , qk, qi, block_q_t , vdr, vec_dot>
6894
+ case 0x00040004 :
6895
+ mul_mat_vec_q<4 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6854
6896
<<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6855
6897
break ;
6856
- // case 5:
6857
- // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6859
- // break;
6860
- // case 6:
6861
- // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6863
- // break;
6864
- // case 7:
6865
- // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6867
- // break;
6868
- // case 8:
6869
- // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6871
- // break;
6872
6898
default :
6873
6899
GGML_ASSERT (false );
6874
6900
// mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
0 commit comments