Skip to content

Commit 1f8bbf1

Browse files
mmqv fixed
1 parent ae3b1ab commit 1f8bbf1

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

ggml-cuda.cu

Lines changed: 21 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -5730,47 +5730,49 @@ inline void ggml_cuda_op_mul_mat_vec(
57305730
#endif // GGML_CUDA_FORCE_DMMV
57315731

57325732
if (use_mul_mat_vec_q) {
5733+
const int nchannels = buffers_contiguous ? 1 : ne02;
5734+
57335735
const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
57345736
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
57355737
size_t as;
5736-
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne02*sizeof(block_q8_1)/QK8_1, &as);
5737-
const int64_t row_stride_q = src1->backend == GGML_BACKEND_CPU ? ne10 : nb11 / sizeof(float);
5738-
const int64_t channel_stride_q = src1->backend == GGML_BACKEND_CPU ? ne10*1 : nb12 / sizeof(float);
5739-
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, ne02, row_stride_q, channel_stride_q, cudaStream_main);
5738+
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*nchannels*sizeof(block_q8_1)/QK8_1, &as);
5739+
const int64_t row_stride_q = buffers_contiguous ? ne10 : nb11 / sizeof(float);
5740+
const int64_t channel_stride_q = buffers_contiguous ? ne10*1 : nb12 / sizeof(float);
5741+
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, nchannels, row_stride_q, channel_stride_q, cudaStream_main);
57405742

5741-
const int row_stride_x = nb01 / ggml_type_size(src0->type);
5742-
const int channel_stride_x = nb02 / ggml_type_size(src0->type);
5743+
const int row_stride_x = buffers_contiguous ? ne00 / ggml_blck_size(src0->type) : nb01 / ggml_type_size(src0->type);
5744+
const int channel_stride_x = buffers_contiguous ? ne00*1 / ggml_blck_size(src0->type) : nb02 / ggml_type_size(src0->type);
57435745
const int channel_stride_y = padded_row_size / QK8_1;
57445746
switch (src0->type) {
57455747
case GGML_TYPE_Q4_0:
5746-
mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5748+
mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57475749
break;
57485750
case GGML_TYPE_Q4_1:
5749-
mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5751+
mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57505752
break;
57515753
case GGML_TYPE_Q5_0:
5752-
mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5754+
mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57535755
break;
57545756
case GGML_TYPE_Q5_1:
5755-
mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5757+
mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57565758
break;
57575759
case GGML_TYPE_Q8_0:
5758-
mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5760+
mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57595761
break;
57605762
case GGML_TYPE_Q2_K:
5761-
mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5763+
mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57625764
break;
57635765
case GGML_TYPE_Q3_K:
5764-
mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5766+
mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57655767
break;
57665768
case GGML_TYPE_Q4_K:
5767-
mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5769+
mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57685770
break;
57695771
case GGML_TYPE_Q5_K:
5770-
mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5772+
mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57715773
break;
57725774
case GGML_TYPE_Q6_K:
5773-
mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
5775+
mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, nchannels, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main);
57745776
break;
57755777
default:
57765778
GGML_ASSERT(false);
@@ -5779,7 +5781,7 @@ inline void ggml_cuda_op_mul_mat_vec(
57795781

57805782
ggml_cuda_pool_free(src1_q8_1, as);
57815783
} else {
5782-
GGML_ASSERT(buffers_contiguous || ne02 == 1);
5784+
GGML_ASSERT(buffers_contiguous);
57835785

57845786
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
57855787
#ifdef GGML_CUDA_F16
@@ -6548,7 +6550,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
65486550
}
65496551

65506552
// no quantized non-contiguous support for lower CC kernels implemented
6551-
const bool nc_okay = src0->type == GGML_TYPE_F16 || g_compute_capabilities[g_main_device] >= MIN_CC_DP4A;
6553+
// const bool nc_okay = src0->type == GGML_TYPE_F16 || g_compute_capabilities[g_main_device] >= MIN_CC_DP4A;
6554+
const bool nc_okay = false;
65526555

65536556
if (all_on_device && nc_okay && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
65546557
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);

0 commit comments

Comments (0)