Skip to content

Commit c0daa66

Browse files
works
1 parent 19df43a commit c0daa66

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

ggml-cuda.cu

Lines changed: 24 additions & 15 deletions
Original file line number · Diff line number · Diff line change
@@ -5522,6 +5522,7 @@ inline void ggml_cuda_op_rms_norm(
55225522
(void) i1;
55235523
}
55245524

5525+
template <bool buffers_contiguous>
55255526
inline void ggml_cuda_op_mul_mat_q(
55265527
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
55275528
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -5534,15 +5535,18 @@ inline void ggml_cuda_op_mul_mat_q(
55345535
const int64_t ne00 = src0->ne[0];
55355536
const int64_t ne02 = src0->ne[2];
55365537

5537-
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
5538-
const size_t nb01 = src0->nb[1];
5539-
const size_t nb02 = src0->nb[2];
5540-
55415538
const int64_t ne10 = src1->ne[0];
55425539
const int64_t ne11 = src1->ne[1];
55435540

55445541
const int64_t ne0 = dst->ne[0];
55455542

5543+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
5544+
const size_t nb01 = src0->nb[1];
5545+
const size_t nb02 = src0->nb[2];
5546+
5547+
const size_t nb11 = src1->nb[1];
5548+
const size_t nb12 = src1->nb[2];
5549+
55465550
const int64_t i01_diff = i01_high - i01_low;
55475551

55485552
int id;
@@ -5552,19 +5556,19 @@ inline void ggml_cuda_op_mul_mat_q(
55525556
// nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
55535557
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
55545558

5555-
const int nchannels = src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU &&
5556-
dst->backend == GGML_BACKEND_GPU && ggml_is_contiguous(src1) ? ne02 : 1;
5559+
const int nchannels = buffers_contiguous ? 1 : ne02;
55575560

55585561
const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
55595562
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
55605563
size_t as;
55615564
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*nchannels*sizeof(block_q8_1)/QK8_1, &as);
5562-
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, nchannels, ne10, ne10*ne11, cudaStream_main);
5565+
const int64_t src1_row_stride = buffers_contiguous ? ne10 : nb11 / sizeof(float);
5566+
const int64_t src1_channel_stride = buffers_contiguous ? ne10*ne11 : nb12 / sizeof(float);
5567+
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, nchannels,
5568+
src1_row_stride, src1_channel_stride, cudaStream_main);
55635569

5564-
// const int row_stride = nb01 / ggml_type_size(src0->type);
5565-
const int row_stride = src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU &&
5566-
dst->backend == GGML_BACKEND_GPU && ggml_is_contiguous(src1) ? nb01 / ggml_type_size(src0->type) : ne10 / ggml_blck_size(src0->type);
5567-
const int channel_stride_x = nb02 / ggml_type_size(src0->type);
5570+
const int row_stride = buffers_contiguous ? ne10 / ggml_blck_size(src0->type) : nb01 / ggml_type_size(src0->type);
5571+
const int channel_stride_x = buffers_contiguous ? ne10*ne11 / ggml_blck_size(src0->type) : nb02 / ggml_type_size(src0->type);
55685572
const int channel_stride_y = padded_row_size*ne11 / QK8_1;
55695573

55705574
switch (src0->type) {
@@ -5681,6 +5685,9 @@ inline void ggml_cuda_op_mul_mat_vec(
56815685
const int64_t nb01 = src0->nb[1];
56825686
const int64_t nb02 = src0->nb[2];
56835687

5688+
const int64_t nb11 = src1->nb[1];
5689+
const int64_t nb12 = src1->nb[2];
5690+
56845691
const int64_t nrows = i01_high - i01_low;
56855692

56865693
#ifdef GGML_CUDA_FORCE_DMMV
@@ -5713,7 +5720,9 @@ inline void ggml_cuda_op_mul_mat_vec(
57135720
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
57145721
size_t as;
57155722
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne02*sizeof(block_q8_1)/QK8_1, &as);
5716-
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, ne02, ne10, ne10*1, cudaStream_main);
5723+
const int64_t row_stride = src1->backend == GGML_BACKEND_CPU ? ne10 : nb11 / sizeof(float);
5724+
const int64_t channel_stride = src1->backend == GGML_BACKEND_CPU ? ne10*1 : nb12 / sizeof(float);
5725+
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, ne02, row_stride, channel_stride, cudaStream_main);
57175726

57185727
const int row_delta = nb01 / ggml_type_size(src0->type);
57195728
const int channel_delta = nb02 / ggml_type_size(src0->type);
@@ -6433,7 +6442,7 @@ void ggml_cuda_mul_mat_nc(const ggml_tensor * src0, const ggml_tensor * src1, gg
64336442
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
64346443
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
64356444

6436-
ggml_cuda_op_mul_mat_q(src0, src1, dst, src0_ddq, nullptr, src1_ddf, dst_ddf, 0, 0, ne01, 0, cudaStream_main);
6445+
ggml_cuda_op_mul_mat_q<false>(src0, src1, dst, src0_ddq, nullptr, src1_ddf, dst_ddf, 0, 0, ne01, 0, cudaStream_main);
64376446
CUDA_CHECK(cudaGetLastError());
64386447
}
64396448

@@ -6534,10 +6543,10 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
65346543
}
65356544

65366545
if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
6537-
if (all_on_device && src0->backend != GGML_BACKEND_GPU_SPLIT && ggml_is_contiguous(src1)) {
6546+
if (all_on_device && src0->backend != GGML_BACKEND_GPU_SPLIT) {
65386547
ggml_cuda_mul_mat_nc(src0, src1, dst);
65396548
} else {
6540-
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
6549+
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q<true>, false, false);
65416550
}
65426551
} else {
65436552
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);

0 commit comments

Comments (0)