Skip to content

Commit 19df43a

Browse files
small refactor
1 parent 6ff96b1 commit 19df43a

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ggml-cuda.cu

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4297,12 +4297,13 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
4297 4297
}
4298 4298

4299 4299
static void quantize_row_q8_1_cuda(
4300-
const float * x, void * vy, const int kx, const int ky, const int kx_padded, const int nchannels, cudaStream_t stream) {
4300+
const float * x, void * vy, const int kx, const int ky, const int kx_padded, const int nchannels,
4301+
const int row_stride, const int channel_stride, cudaStream_t stream) {
4301 4302

4302 4303
const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
4303 4304
const dim3 num_blocks(block_num_x, ky, nchannels);
4304 4305
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
4305-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded, ky, kx, ky*kx);
4306+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded, ky, row_stride, channel_stride);
4306 4307
}
4307 4308

4308 4309
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -5558,7 +5559,7 @@ inline void ggml_cuda_op_mul_mat_q(
5558 5559
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
5559 5560
size_t as;
5560 5561
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*nchannels*sizeof(block_q8_1)/QK8_1, &as);
5561-
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, nchannels, cudaStream_main);
5562+
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, nchannels, ne10, ne10*ne11, cudaStream_main);
5562 5563

5563 5564
// const int row_stride = nb01 / ggml_type_size(src0->type);
5564 5565
const int row_stride = src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU &&
@@ -5712,7 +5713,7 @@ inline void ggml_cuda_op_mul_mat_vec(
5712 5713
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
5713 5714
size_t as;
5714 5715
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne02*sizeof(block_q8_1)/QK8_1, &as);
5715-
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, ne02, cudaStream_main);
5716+
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, ne02, ne10, ne10*1, cudaStream_main);
5716 5717

5717 5718
const int row_delta = nb01 / ggml_type_size(src0->type);
5718 5719
const int channel_delta = nb02 / ggml_type_size(src0->type);

0 commit comments

Comments
 (0)