Commit 7a11eb3

cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X (#8800)
* cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X
* update asserts
* only use dmmv for supported types
* add test
1 parent c8a0090 commit 7a11eb3

File tree

4 files changed: +22 -11 lines changed

ggml/src/ggml-cuda.cu

Lines changed: 2 additions & 3 deletions
@@ -1885,10 +1885,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
-        && src1->ne[1] == 1;
+        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
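
Why the guard changed: the dmmv kernel walks each row in fixed strides of 2*GGML_CUDA_DMMV_X columns with no remainder handling, so the old condition (divisible by GGML_CUDA_DMMV_X and at least 2*GGML_CUDA_DMMV_X) still admitted sizes such as 3*GGML_CUDA_DMMV_X that the kernel cannot process. A minimal sketch of the loop shape, simplified from dequantize_mul_mat_vec in dmmv.cu (dequantization and the warp reduction are omitted; the #define values are the ggml defaults, assumed here for self-containment):

    #define GGML_CUDA_DMMV_X 32 // ggml default, configurable at build time
    #define WARP_SIZE        32

    static __global__ void dmmv_loop_sketch(const float * y, float * dst,
                                            const int ncols, const int nrows) {
        const int row = blockIdx.x*blockDim.y + threadIdx.y;
        if (row >= nrows) {
            return;
        }

        const int tid           = threadIdx.x;
        const int iter_stride   = 2*GGML_CUDA_DMMV_X;      // fixed stride, no tail loop
        const int vals_per_iter = iter_stride / WARP_SIZE; // 2 values per thread per iteration

        float tmp = 0.0f;
        for (int i = 0; i < ncols; i += iter_stride) {
            const int col = i + vals_per_iter*tid;
            // the real kernel dequantizes a pair of weights for (col, col + 1) here;
            // reading col + 1 stays in bounds only if ncols % (2*GGML_CUDA_DMMV_X) == 0
            tmp += y[col] + y[col + 1];
        }
        if (tid == 0) {
            dst[row] = tmp; // the real kernel does a warp reduction before the store
        }
    }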

ggml/src/ggml-cuda/dmmv.cu

Lines changed: 15 additions & 6 deletions
@@ -500,7 +500,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
     const dim3 block_nums(block_num_y, 1, 1);
@@ -510,7 +510,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -519,7 +519,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -528,7 +528,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -537,7 +537,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -588,7 +588,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -672,3 +672,12 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
+    return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
+        src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
+        src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
+        src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
+        src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
+        src0_type == GGML_TYPE_F16;
+}
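
The helper exists because ggml_is_quantized() is broader than what dmmv implements: quantized types with no dmmv kernel (the IQ quants, for example) previously passed the selection check and could reach the unsupported-type branch inside ggml_cuda_op_dequantize_mul_mat_vec at runtime. A hedged illustration of the contract (the type names are real ggml types; the check function itself is just for exposition):

    // Sketch: what the whitelist guarantees, expressed as assertions.
    static void check_dmmv_type_gate() {
        GGML_ASSERT( ggml_cuda_dmmv_type_supported(GGML_TYPE_Q4_0));   // supported quant type
        GGML_ASSERT( ggml_cuda_dmmv_type_supported(GGML_TYPE_F16));    // the one non-quantized type dmmv handles
        GGML_ASSERT(!ggml_cuda_dmmv_type_supported(GGML_TYPE_IQ2_XS)); // quantized, but has no dmmv kernel
    }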

ggml/src/ggml-cuda/dmmv.cuh

Lines changed: 2 additions & 0 deletions
@@ -16,3 +16,5 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);

tests/test-backend-ops.cpp

Lines changed: 3 additions & 2 deletions
@@ -804,8 +804,7 @@ struct test_cpy : public test_case {
 
     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            std::array<int64_t, 4> permute = {0, 0, 0, 0},
-            bool _dst_use_permute = false)
+            std::array<int64_t, 4> permute = {0, 0, 0, 0})
         : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
         _src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
 
@@ -2269,6 +2268,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     for (ggml_type type_a : other_types) {
         for (ggml_type type_b : {GGML_TYPE_F32}) {
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), { 1, 1}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
         }
     }
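
The first new case sets the shared dimension k to ggml_blck_size(type_a), i.e., one quant block per row, next to the pre-existing k = 256 case, so mul_mat gets exercised on both sides of the dmmv eligibility boundary. For orientation, a hedged sketch of the values involved (the block sizes are ggml facts; eligibility assumes the default GGML_CUDA_DMMV_X of 32, so the fixed guard requires k % 64 == 0):

    // Example k values the new test produces:
    //   ggml_blck_size(GGML_TYPE_Q8_0) == 32  -> k =  32: dmmv ineligible, another path must handle it
    //   ggml_blck_size(GGML_TYPE_Q6_K) == 256 -> k = 256: dmmv eligible for supported types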

0 commit comments