Commit e0c926d

q5_0 works
1 parent 274b0d7 commit e0c926d

File tree: 1 file changed (+51, -4 lines)

ggml-cuda.cu

Lines changed: 51 additions & 4 deletions
@@ -1369,7 +1369,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
     return vec_dot_q4_0_q8_1_impl(
-        x_ql[i * WARP_SIZE + i + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
+        x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
         x_dm[i * (WARP_SIZE/QI4_0) + k/QI4_0].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }
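Note: this hunk (and the matching q4_1 hunk below) does not change behaviour. i * WARP_SIZE + i + k and i * (WARP_SIZE + 1) + k address the same element; the new form simply spells out the padded row stride of WARP_SIZE + 1 ints that allocate_tiles_q4_0/allocate_tiles_q4_1 already use for tile_x_ql. The +1 padding is presumably there to stagger tile rows across shared-memory banks (a common CUDA idiom); the commit itself only touches the indexing. A minimal host-side sketch of the equivalence, assuming WARP_SIZE is 32 as defined in ggml-cuda.cu:

// Sketch only: verifies the old and new tile indices are identical.
// Assumes WARP_SIZE == 32, as in ggml-cuda.cu.
#include <assert.h>

#define WARP_SIZE 32

int main(void) {
    for (int i = 0; i < 2*WARP_SIZE; ++i) {       // tile row
        for (int k = 0; k < WARP_SIZE; ++k) {     // position within the row
            const int old_idx = i * WARP_SIZE + i + k;    // index before this commit
            const int new_idx = i * (WARP_SIZE + 1) + k;  // index after this commit
            assert(old_idx == new_idx);                   // same element of tile_x_ql
        }
    }
    return 0;
}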

@@ -1441,7 +1441,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
     return vec_dot_q4_1_q8_1_impl(
-        x_ql[i * WARP_SIZE + i + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
+        x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
         x_dm[i * (WARP_SIZE/QI4_1) + k/QI4_1], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }

@@ -1482,6 +1482,43 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
     return vec_dot_q5_0_q8_1_impl(qs, qh, ui0, ui1, bq5_0->d, bq8_1->ds);
 }
 
+static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) {
+
+    __shared__ int   tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
+    __shared__ int   tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/QI5_0)];
+    __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI5_0)];
+
+    *x_ql = tile_x_ql;
+    *x_qh = tile_x_qh;
+    *x_dm = tile_x_d;
+}
+
+static __device__ __forceinline__ void load_tiles_q5_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+    const int kbx  = k / QI5_0;
+    const int kqsx = k % QI5_0;
+
+    const block_q5_0 * bx = ((block_q5_0 *) vx) + i*blocks_per_row + kbx;
+
+    x_ql[i * (WARP_SIZE + 1) + k]         = get_int_from_uint8(bx->qs, kqsx);
+    x_qh[i * (WARP_SIZE / QI5_0) + kbx]   = get_int_from_uint8(bx->qh, 0);
+    x_dm[i * (WARP_SIZE / QI5_0) + kbx].x = bx->d;
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0;
+
+    return vec_dot_q5_0_q8_1_impl(
+        x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_0)), y_qs[j * (2*WARP_SIZE) + kyqs],
+        y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
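Note: unlike q4_0/q4_1, q5_0 stores the fifth (high) bit of each weight separately from the packed low nibbles, which is why allocate_tiles_q5_0 adds a second tile (tile_x_qh) and load_tiles_q5_0 copies bx->qh into it once per block; the scale goes into x_dm[...].x and the three pieces are recombined inside vec_dot_q5_0_q8_1_impl. For reference, a host-side sketch of the block format being consumed, paraphrased from ggml's CPU dequantization; the field names and bit positions here are my reading of ggml.c, not part of this diff:

// Sketch only: q5_0 block layout and reference dequantization, paraphrased from
// ggml's CPU code (dequantize_row_q5_0). Treat field names and bit positions as
// assumptions; ggml.h / ggml.c are authoritative.
#include <stdint.h>
#include <string.h>

#define QK5_0 32

typedef uint16_t ggml_fp16_t;          // the scale is stored as an fp16 bit pattern

typedef struct {
    ggml_fp16_t d;                     // scale
    uint8_t     qh[4];                 // 32 high (fifth) bits, one per weight
    uint8_t     qs[QK5_0 / 2];         // 32 low nibbles, two weights per byte
} block_q5_0;

// Dequantizes one block into 32 floats; `d` is the already-decoded fp16 scale.
static void dequantize_block_q5_0(const block_q5_0 * b, const float d, float * y) {
    uint32_t qh;
    memcpy(&qh, b->qh, sizeof(qh));

    for (int j = 0; j < QK5_0/2; ++j) {
        const int xh_0 = ((qh >> (j +  0)) << 4) & 0x10;   // high bit of weight j
        const int xh_1 =  (qh >> (j + 12))       & 0x10;   // high bit of weight j + 16
        y[j]           = d * (((b->qs[j] & 0x0F) | xh_0) - 16);
        y[j + QK5_0/2] = d * (((b->qs[j] >>   4) | xh_1) - 16);
    }
}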

@@ -2547,6 +2584,14 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(const void * vx, const void * vy, float
     mul_mat_q<QK4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
 }
 
+static void ggml_mul_mat_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+    mul_mat_q<QK5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
 static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
     const dim3 block_nums(1, nrows_x, nchannels_x);
     const dim3 block_dims(WARP_SIZE, 1, 1);
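Note: the new launcher follows the same pattern as the existing q4_0/q4_1 launchers: each block appears to cover a 2*WARP_SIZE x WARP_SIZE tile of dst, so the grid is a ceil-division over the rows of the quantized matrix and the columns of the activations. A small host-side sketch of that arithmetic (the matrix sizes below are made-up example values, not from the commit):

// Sketch only: the grid sizing used by ggml_mul_mat_q5_0_q8_1_cuda, applied to
// example dimensions. The sizes below are illustrative, not from the commit.
#include <stdio.h>

#define WARP_SIZE 32

int main(void) {
    const int nrows_x = 4096;   // rows of the quantized src0 (example)
    const int ncols_y = 512;    // columns of src1 (example)

    // same ceil-divisions as in the launcher above
    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;

    printf("grid = (%d, %d, 1), block = (%d, %d, 1)\n",
           block_num_x, block_num_y, WARP_SIZE, WARP_SIZE/4);   // 64 x 16 blocks here
    return 0;
}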
@@ -3004,6 +3049,9 @@ inline void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q4_1:
             ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
             break;
+        case GGML_TYPE_Q5_0:
+            ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
         default:
             GGML_ASSERT(false);
             break;

@@ -3753,8 +3801,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
         } else {
-            // if (src0->type == GGML_TYPE_Q4_0) {
-            if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) {
+            if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0) {
                 ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
             } else {
                 ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
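Note: with the last two hunks, the mul_mat_q path now handles Q4_0, Q4_1 and Q5_0, while everything else in this branch still falls back to cuBLAS (the single-column case above keeps using the dequantize-mul-mat-vec path). Summarised as a standalone predicate; this is only a sketch, ggml-cuda.cu inlines the condition directly and the helper name is hypothetical:

// Sketch only: the dispatch condition from ggml_cuda_mul_mat after this commit,
// factored into a hypothetical helper. Not part of the actual source.
#include "ggml.h"

static bool use_mul_mat_q(const ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:   // newly enabled by this commit
            return true;
        default:
            return false;      // other types still go through the cuBLAS path
    }
}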
