@@ -1220,7 +1220,6 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
1220
1220
sumi = __dp4a (vi1, ui1, sumi);
1221
1221
1222
1222
return sumi*d;
1223
-
1224
1223
}
1225
1224
1226
1225
static __device__ __forceinline__ float vec_dot_q4_1_q8_1 (const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
@@ -1241,7 +1240,37 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
1241
1240
sumi = __dp4a (vi1, ui1, sumi);
1242
1241
1243
1242
return sumi*d + m*s / QI4_1;
1243
+ }
1244
+
1245
// Partial dot product between a q5_0 block and a q8_1 block for sub-index iqs.
//
// vbq   - pointer to the q5_0 block (reinterpreted as const block_q5_0 *)
// bq8_1 - pointer to the q8_1 block supplying the 8-bit operands
// iqs   - selects which 4-byte group of packed nibbles (and which nibble of
//         the qh bytes) is processed by this call
//
// Returns the integer dot product of the reconstructed 5-bit values with the
// selected q8_1 integers, scaled by the product of both blocks' d fields.
//
// NOTE: __dp4a requires compute capability 6.1 or newer.
// NOTE(review): the offset of the second q8_1 word uses QI4_0, not a
// q5_0-specific constant - presumably both have the same value; confirm.
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
    const block_q5_0 * x = (const block_q5_0 *) vbq;

    // Four bytes of packed nibbles; copied with memcpy rather than read through an int pointer.
    int packed;
    memcpy(&packed, &x->qs[sizeof(int) * (iqs + 0)], sizeof(int));

    // Each qh byte carries two groups of four high bits; iqs%2 picks the nibble.
    const int nibble_shift = 4*(iqs%2);
    const int hbits_lo = x->qh[iqs/2 + 0] >> nibble_shift;
    const int hbits_hi = x->qh[iqs/2 + 2] >> nibble_shift;

    // View the q8_1 quants as 32-bit words: word iqs and word iqs + QI4_0.
    const int * y = (const int *) bq8_1->qs;
    const int y_lo = y[iqs + 0];
    const int y_hi = y[iqs + QI4_0];

    const float d = x->d * bq8_1->d;

    // Low nibbles and high nibbles of the packed word, one value per byte.
    int v_lo = (packed >> 0) & 0x0F0F0F0F;
    int v_hi = (packed >> 4) & 0x0F0F0F0F;

    // Scatter high bit k into bit 4 of byte k: shifts 4, 11, 18, 25 with
    // masks 0x10, 0x1000, 0x100000, 0x10000000.
#pragma unroll
    for (int k = 0; k < 4; ++k) {
        const int up   = 4 + 7*k;
        const int mask = 0x10 << (8*k);
        v_lo |= (hbits_lo << up) & mask;
        v_hi |= (hbits_hi << up) & mask;
    }

    // Per-byte subtraction of 16 from every reconstructed 5-bit value.
    v_lo = __vsub4(v_lo, 0x10101010);
    v_hi = __vsub4(v_hi, 0x10101010);

    // Two chained 4-way byte dot products accumulate into one integer.
    int acc = __dp4a(v_lo, y_lo, 0);
    acc = __dp4a(v_hi, y_hi, acc);

    return acc*d;
}
1246
1275
1247
1276
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -1796,6 +1825,15 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
1796
1825
<<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
1797
1826
}
1798
1827
1828
// Host-side launcher for the q5_0 x q8_1 matrix-vector kernel.
//
// vx, vy - operand pointers forwarded unchanged to the kernel
// dst    - output pointer forwarded unchanged to the kernel
// ncols  - column count; must be a multiple of GGML_CUDA_DMMV_X (asserted)
// nrows  - row count; the grid's y extent is ceil(nrows / GGML_CUDA_DMMV_Y)
// stream - CUDA stream the kernel is enqueued on
//
// Launch shape: grid (1, ceil(nrows / GGML_CUDA_DMMV_Y), 1),
//               block (WARP_SIZE, GGML_CUDA_DMMV_Y, 1), no dynamic shared memory.
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

    // Ceil-divide so a trailing partial group of rows still gets a block.
    const int grid_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;

    const dim3 grid_shape(1, grid_y, 1);
    const dim3 block_shape(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);

    mul_mat_vec_q<QK5_0, block_q5_0, vec_dot_q5_0_q8_1><<<grid_shape, block_shape, 0, stream>>>(
        vx, vy, dst, ncols, nrows);
}
1836
+
1799
1837
static void convert_fp16_to_fp32_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
1800
1838
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
1801
1839
dequantize_block<1 , 1 , convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
@@ -2319,6 +2357,9 @@ inline void ggml_cuda_op_mul_mat_vec_q(
2319
2357
case GGML_TYPE_Q4_1:
2320
2358
mul_mat_vec_q4_1_q8_1_cuda (src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
2321
2359
break ;
2360
+ case GGML_TYPE_Q5_0:
2361
+ mul_mat_vec_q5_0_q8_1_cuda (src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
2362
+ break ;
2322
2363
default :
2323
2364
GGML_ASSERT (false );
2324
2365
break ;
@@ -2875,7 +2916,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
2875
2916
} else if (ggml_is_quantized (src0->type ) || src0->type == GGML_TYPE_F16) {
2876
2917
if (src1->ne [1 ] == 1 && src0->ne [0 ] % GGML_CUDA_DMMV_X == 0 && src0->ne [1 ] % GGML_CUDA_DMMV_Y == 0 ) {
2877
2918
bool use_mul_mat_vec_q = false ;
2878
- use_mul_mat_vec_q = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1;
2919
+ use_mul_mat_vec_q = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0-> type == GGML_TYPE_Q5_0 ;
2879
2920
if (use_mul_mat_vec_q) {
2880
2921
ggml_cuda_op (src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, false , false );
2881
2922
} else {
0 commit comments