@@ -1565,6 +1565,43 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
1565
1565
return vec_dot_q5_1_q8_1_impl (qs, qh, ui0, ui1, bq5_1->dm , bq8_1->ds );
1566
1566
}
1567
1567
1568
+ static __device__ __forceinline__ void allocate_tiles_q5_1 (int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) {
1569
+
1570
+ __shared__ int tile_x_ql[(2 *WARP_SIZE) * (WARP_SIZE + 1 )];
1571
+ __shared__ int tile_x_qh[(2 *WARP_SIZE) * (WARP_SIZE/QI5_1)];
1572
+ __shared__ half2 tile_x_dm[(2 *WARP_SIZE) * (WARP_SIZE/QI5_1)];
1573
+
1574
+ *x_ql = tile_x_ql;
1575
+ *x_qh = tile_x_qh;
1576
+ *x_dm = tile_x_dm;
1577
+ }
1578
+
1579
+ static __device__ __forceinline__ void load_tiles_q5_1 (
1580
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1581
+ int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
1582
+
1583
+ const int kbx = k / QI5_1;
1584
+ const int kqsx = k % QI5_1;
1585
+
1586
+ const block_q5_1 * bx = ((block_q5_1 *) vx) + i*blocks_per_row + kbx;
1587
+
1588
+ x_ql[i * (WARP_SIZE + 1 ) + k] = get_int_from_uint8 (bx->qs , kqsx);
1589
+ x_qh[i * (WARP_SIZE / QI5_1) + kbx] = get_int_from_uint8 (bx->qh , 0 );
1590
+ x_dm[i * (WARP_SIZE / QI5_1) + kbx] = bx->dm ;
1591
+ }
1592
+
1593
+ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat (
1594
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc,
1595
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
1596
+
1597
+ const int kyqs = k % (QI8_1/2 ) + QI8_1 * (k / (QI8_1/2 ));
1598
+ const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0;
1599
+
1600
+ return vec_dot_q5_1_q8_1_impl (
1601
+ x_ql[i * (WARP_SIZE + 1 ) + k], x_qh[index_bx] >> (4 * (k % QI5_1)), y_qs[j * (2 *WARP_SIZE) + kyqs],
1602
+ y_qs[j * (2 *WARP_SIZE) + kyqs + (QI8_1/2 )], x_dm[index_bx], y_ds[j * (2 *WARP_SIZE/QI8_1) + 2 *k/QI8_1]);
1603
+ }
1604
+
1568
1605
static __device__ __forceinline__ float vec_dot_q8_0_q8_1 (
1569
1606
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
1570
1607
@@ -2602,6 +2639,14 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(const void * vx, const void * vy, float
2602
2639
mul_mat_q<QK5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
2603
2640
}
2604
2641
2642
+ static void ggml_mul_mat_q5_1_q8_1_cuda (const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
2643
+ const int block_num_x = (nrows_x + 2 *WARP_SIZE - 1 ) / (2 *WARP_SIZE);
2644
+ const int block_num_y = (ncols_y + WARP_SIZE - 1 ) / WARP_SIZE;
2645
+ const dim3 block_nums (block_num_x, block_num_y, 1 );
2646
+ const dim3 block_dims (WARP_SIZE, WARP_SIZE/4 , 1 );
2647
+ mul_mat_q<QK5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, vec_dot_q5_1_q8_1_mul_mat><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
2648
+ }
2649
+
2605
2650
static void ggml_mul_mat_p021_f16_f32_cuda (const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
2606
2651
const dim3 block_nums (1 , nrows_x, nchannels_x);
2607
2652
const dim3 block_dims (WARP_SIZE, 1 , 1 );
@@ -3071,6 +3116,9 @@ inline void ggml_cuda_op_mul_mat_q(
3071
3116
case GGML_TYPE_Q5_0:
3072
3117
ggml_mul_mat_q5_0_q8_1_cuda (src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
3073
3118
break ;
3119
+ case GGML_TYPE_Q5_1:
3120
+ ggml_mul_mat_q5_1_q8_1_cuda (src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
3121
+ break ;
3074
3122
default :
3075
3123
GGML_ASSERT (false );
3076
3124
break ;
@@ -3820,7 +3868,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
3820
3868
if (src1->ne [1 ] == 1 && src0->ne [0 ] % GGML_CUDA_DMMV_X == 0 ) {
3821
3869
ggml_cuda_op (src0, src1, dst, ggml_cuda_op_mul_mat_vec, false , false );
3822
3870
} else {
3823
- if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0) {
3871
+ if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 ||
3872
+ src0->type == GGML_TYPE_Q5_1) {
3824
3873
ggml_cuda_op (src0, src1, dst, ggml_cuda_op_mul_mat_q, false , false );
3825
3874
} else {
3826
3875
ggml_cuda_op (src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true , false );
0 commit comments