
Commit 3bdc4cd

CUDA: mul_mat_vec_q tiling, refactor mul mat logic (#5434)
* CUDA: mul_mat_vec_q tiling, refactor mul mat logic

Co-authored-by: slaren <[email protected]>
1 parent 2891c8a commit 3bdc4cd
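Note (editorial, not part of the commit): the sketch below restates the launch geometry this change introduces, to make the diff easier to follow. On NVIDIA and on AMD GPUs older than RDNA2, each CUDA block now computes rows_per_cuda_block rows of dst (1 row when ncols_y == 1, otherwise 2), with 4 warps per block for ncols_y <= 4 and 2 warps for larger batches; RDNA2 and newer keep one warp and one row per block. The struct and function names below are hypothetical and do not exist in ggml-cuda.cu.

// Illustrative host-side restatement of the MMVQ tiling parameters (hypothetical names).
#include <cstdint>
#include <cstdio>

struct mmvq_launch_params {
    int64_t nwarps;              // warps per block (blockDim.y)
    int64_t rows_per_cuda_block; // dst rows produced per block
    int64_t nblocks;             // gridDim.x
};

// Mirrors the switch added to mul_mat_vec_q_cuda() for NVIDIA / pre-RDNA2 AMD.
static mmvq_launch_params mmvq_launch_params_for(int64_t ncols_y, int64_t nrows_x) {
    mmvq_launch_params p;
    p.nwarps              = ncols_y <= 4 ? 4 : 2;
    p.rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
    p.nblocks             = (nrows_x + p.rows_per_cuda_block - 1) / p.rows_per_cuda_block;
    return p;
}

int main() {
    // example: a 4096-row quantized weight matrix multiplied by a batch of 8 columns
    const mmvq_launch_params p = mmvq_launch_params_for(/*ncols_y=*/8, /*nrows_x=*/4096);
    std::printf("nwarps=%lld rows_per_block=%lld nblocks=%lld\n",
                (long long) p.nwarps, (long long) p.rows_per_cuda_block, (long long) p.nblocks);
    return 0;
}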


ggml-cuda.cu

Lines changed: 140 additions & 107 deletions
@@ -150,8 +150,8 @@
 #define CUDA_USE_TENSOR_CORES
 #endif
 
-// max batch size to use MMQ kernels when tensor cores are available
-#define MMQ_MAX_BATCH_SIZE 32
+#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
+#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
@@ -5310,51 +5310,59 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-#define MMVQ_NWARPS_NVIDIA 4
-#define MMVQ_NWARPS_AMD_RDNA2 1
-#define MMVQ_NWARPS_AMD_OLD 4
-
-template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
-    const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+    constexpr int nwarps              = 1;
+    constexpr int rows_per_cuda_block = 1;
+#else
+    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
+    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
 
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    const int row = blockIdx.x;
-
-    const int blocks_per_row_x = ncols_x / qk;
-    const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int row0 = rows_per_cuda_block*blockIdx.x;
+    const     int blocks_per_row_x = ncols_x / qk;
+    const     int blocks_per_col_y = nrows_y / QK8_1;
+    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
 
     // partial sum for each thread
-    float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
+    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q_t * x = (const block_q_t *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
-        const int ibx = row*blocks_per_row_x + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
 
-        const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
+        // x block quant index when casting the quants to int
+        const int kqs = vdr * (tid % (qi/vdr));
 
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
-            tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp[j][i] += vec_dot_q_cuda(
+                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
+            }
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
-            tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+            }
         }
     }
     __syncthreads();
@@ -5366,13 +5374,16 @@ static __global__ void mul_mat_vec_q(
 #pragma unroll
     for (int j = 0; j < ncols_y; ++j) {
 #pragma unroll
-        for (int i = 0; i < nwarps-1; ++i) {
-            tmp[j] += tmp_shared[i][j][threadIdx.x];
+        for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+            for (int l = 0; l < nwarps-1; ++l) {
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+            }
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
-        tmp[j] = warp_reduce_sum(tmp[j]);
 
-        if (threadIdx.x == 0) {
-            dst[j*nrows_dst + row] = tmp[j];
+        if (threadIdx.x < rows_per_cuda_block) {
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
 }
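Note (editorial, not part of the commit): in the rewritten kernel every thread accumulates into tmp[ncols_y][rows_per_cuda_block], so a block now produces a small tile of dst (rows_per_cuda_block rows for each of the ncols_y columns of y) instead of a single row. Ignoring quantization and the warp/shared-memory reduction, the tile written by one block corresponds to the plain C++ sketch below; the function name is hypothetical and the arrays are assumed to be already-dequantized floats.

#include <cstddef>
#include <vector>

// Hypothetical fp32 reference for the dst tile of one CUDA block.
// x:   nrows_x rows of length ncols_x, row-major.
// y:   ncols_y columns of length ncols_x, column j starting at j*ncols_x.
// dst: column j starting at j*nrows_dst (cf. dst[j*nrows_dst + row0 + threadIdx.x] above).
static void mul_mat_vec_tile_reference(
        const std::vector<float> & x, const std::vector<float> & y, std::vector<float> & dst,
        int ncols_x, int nrows_dst, int ncols_y, int rows_per_cuda_block, int block_idx_x) {
    const int row0 = rows_per_cuda_block*block_idx_x;
    for (int j = 0; j < ncols_y; ++j) {                 // one dst column per column of y
        for (int i = 0; i < rows_per_cuda_block; ++i) { // the rows of this block's tile
            float sum = 0.0f;
            for (int k = 0; k < ncols_x; ++k) {
                sum += x[(size_t)(row0 + i)*ncols_x + k] * y[(size_t)j*ncols_x + k];
            }
            dst[(size_t)j*nrows_dst + row0 + i] = sum;
        }
    }
}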
@@ -6851,65 +6862,75 @@ static void mul_mat_vec_q_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % qk == 0);
-    GGML_ASSERT(ncols_y <= 4);
+    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
-    int nwarps;
-    if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
-        nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
-    } else {
-        nwarps = MMVQ_NWARPS_NVIDIA;
-    }
+    int64_t nwarps = 1;
+    int64_t rows_per_cuda_block = 1;
 
-    const dim3 block_nums(nrows_x, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    switch (nwarps) {
-        case 1: switch(ncols_y) {
+    if (g_device_caps[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+        switch(ncols_y) {
            case 1:
-                mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                nwarps = 4;
+                rows_per_cuda_block = 1;
                break;
            case 2:
-                mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
            case 3:
-                mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
            case 4:
-                mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        } break;
-        case 4: switch(ncols_y) {
-            case 1:
-                mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                nwarps = 4;
+                rows_per_cuda_block = 2;
                break;
-            case 2:
-                mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            case 3:
-                mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            case 4:
-                mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                nwarps = 2;
+                rows_per_cuda_block = 2;
                break;
            default:
                GGML_ASSERT(false);
                break;
-        } break;
+        }
+    }
+    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
 
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 2:
+            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 3:
+            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 4:
+            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 5:
+            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 6:
+            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 7:
+            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 8:
+            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
        default:
            GGML_ASSERT(false);
            break;
@@ -9735,7 +9756,7 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
 
@@ -9893,39 +9914,69 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
     int64_t min_compute_capability = INT_MAX;
 
+    bool any_pascal_with_slow_fp16 = false;
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < g_device_count; ++id) {
-            if (min_compute_capability > g_device_caps[id].cc && tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
+            // skip devices that are not going to do any work:
+            if (tensor_split[id] >= (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
+                continue;
+            }
+
+            if (min_compute_capability > g_device_caps[id].cc) {
                 min_compute_capability = g_device_caps[id].cc;
             }
+            if (g_device_caps[id].cc == 610) {
+                any_pascal_with_slow_fp16 = true;
+            }
         }
     } else {
-        min_compute_capability = g_device_caps[g_main_device].cc;
+        min_compute_capability   = g_device_caps[g_main_device].cc;
+        any_pascal_with_slow_fp16 = g_device_caps[g_main_device].cc == 610;
     }
 
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
     const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
-    bool use_mul_mat_q = ggml_is_quantized(src0->type);
+
 #ifdef CUDA_USE_TENSOR_CORES
     use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
 #endif // CUDA_USE_TENSOR_CORES
 
 #else
 
-    const bool fp16_performance_good = min_compute_capability >= CC_VOLTA;
-    bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+    // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
+    const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
+
+    // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
+    use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
+    use_mul_mat_q     = use_mul_mat_q     && min_compute_capability >= MIN_CC_DP4A;
+
 #ifdef CUDA_USE_TENSOR_CORES
     // when tensor cores are available, use them for large batch size
     // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-    use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE);
+    use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // CUDA_USE_TENSOR_CORES
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
-    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+    // if mmvq is available it's a better choice than dmmv:
+#ifndef GGML_CUDA_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_CUDA_FORCE_DMMV
 
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
@@ -9943,33 +9994,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
     } else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) {
-#ifdef GGML_CUDA_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
-#endif // GGML_CUDA_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
-            } else {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
-            } else if (use_mul_mat_q) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
-            } else {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
-            }
-        }
+        ggml_cuda_mul_mat_batched_cublas(src0, src1, dst);
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     }
 }
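Note (editorial, not part of the commit): after the refactor the kernel choice in ggml_cuda_mul_mat is driven by the three booleans computed near the top of the function instead of nested type checks. Setting aside the no-split F16 fast paths above this hunk, the precedence is dequantize_mul_mat_vec, then mul_mat_vec_q, then mul_mat_q, then cuBLAS; use_dequantize_mul_mat_vec is cleared whenever MMVQ is usable unless GGML_CUDA_FORCE_DMMV is defined. The enum and helper below are hypothetical stand-ins for the real ggml_cuda_op_mul_mat_* callbacks.

// Illustrative restatement of the post-refactor selection order (hypothetical names).
enum class mul_mat_kernel {
    dequantize_mul_mat_vec, // DMMV: dequantize, then fp32 mat-vec (src1->ne[1] == 1)
    mul_mat_vec_q,          // MMVQ: quantized mat-vec, src1->ne[1] <= MMVQ_MAX_BATCH_SIZE (8)
    mul_mat_q,              // MMQ:  quantized mat-mat
    cublas,                 // fallback path
};

static mul_mat_kernel choose_mul_mat_kernel(
        bool use_dequantize_mul_mat_vec, bool use_mul_mat_vec_q, bool use_mul_mat_q) {
    // Same precedence as the if/else chain in the hunk above; the flags are assumed
    // to have already been filtered by compute capability (MIN_CC_DP4A) and, when
    // tensor cores are usable, by MMQ_MAX_BATCH_SIZE.
    if (use_dequantize_mul_mat_vec) {
        return mul_mat_kernel::dequantize_mul_mat_vec;
    }
    if (use_mul_mat_vec_q) {
        return mul_mat_kernel::mul_mat_vec_q;
    }
    if (use_mul_mat_q) {
        return mul_mat_kernel::mul_mat_q;
    }
    return mul_mat_kernel::cublas;
}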
