@@ -715,7 +715,8 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
715
715
716
716
// ================================== k-quants
717
717
718
- static __global__ void dequantize_block_q2_K (const void * __restrict__ vx, float * __restrict__ yy) {
718
+ template <typename dst_t >
719
+ static __global__ void dequantize_block_q2_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
719
720
720
721
const int i = blockIdx .x ;
721
722
const block_q2_K * x = (const block_q2_K *) vx;
@@ -727,7 +728,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
727
728
const int is = 8 *n + l/16 ;
728
729
729
730
const uint8_t q = x[i].qs [32 *n + l];
730
- float * y = yy + i*QK_K + 128 *n;
731
+ dst_t * y = yy + i*QK_K + 128 *n;
731
732
732
733
float dall = __low2half (x[i].dm );
733
734
float dmin = __high2half (x[i].dm );
@@ -739,7 +740,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
739
740
const int is = tid/16 ; // 0 or 1
740
741
const int il = tid%16 ; // 0...15
741
742
const uint8_t q = x[i].qs [il] >> (2 *is);
742
- float * y = yy + i*QK_K + 16 *is + il;
743
+ dst_t * y = yy + i*QK_K + 16 *is + il;
743
744
float dall = __low2half (x[i].dm );
744
745
float dmin = __high2half (x[i].dm );
745
746
y[ 0 ] = dall * (x[i].scales [is+0 ] & 0xF ) * ((q >> 0 ) & 3 ) - dmin * (x[i].scales [is+0 ] >> 4 );
@@ -748,7 +749,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
748
749
749
750
}
750
751
751
- static __global__ void dequantize_block_q3_K (const void * __restrict__ vx, float * __restrict__ yy) {
752
+ template <typename dst_t >
753
+ static __global__ void dequantize_block_q3_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
752
754
753
755
const int i = blockIdx .x ;
754
756
const block_q3_K * x = (const block_q3_K *) vx;
@@ -772,7 +774,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
772
774
float d_all = x[i].d ;
773
775
float dl = d_all * (us - 32 );
774
776
775
- float * y = yy + i*QK_K + 128 *n + 32 *j;
777
+ dst_t * y = yy + i*QK_K + 128 *n + 32 *j;
776
778
const uint8_t * q = x[i].qs + 32 *n;
777
779
const uint8_t * hm = x[i].hmask ;
778
780
@@ -784,7 +786,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
784
786
const int im = il/8 ; // 0...1
785
787
const int in = il%8 ; // 0...7
786
788
787
- float * y = yy + i*QK_K + 16 *is + il;
789
+ dst_t * y = yy + i*QK_K + 16 *is + il;
788
790
789
791
const uint8_t q = x[i].qs [il] >> (2 *is);
790
792
const uint8_t h = x[i].hmask [in] >> (2 *is + im);
@@ -812,7 +814,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
812
814
}
813
815
#endif
814
816
815
- static __global__ void dequantize_block_q4_K (const void * __restrict__ vx, float * __restrict__ yy) {
817
+ template <typename dst_t >
818
+ static __global__ void dequantize_block_q4_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
816
819
const block_q4_K * x = (const block_q4_K *) vx;
817
820
818
821
const int i = blockIdx .x ;
@@ -825,7 +828,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
825
828
const int is = 2 *il;
826
829
const int n = 4 ;
827
830
828
- float * y = yy + i*QK_K + 64 *il + n*ir;
831
+ dst_t * y = yy + i*QK_K + 64 *il + n*ir;
829
832
830
833
const float dall = __low2half (x[i].dm );
831
834
const float dmin = __high2half (x[i].dm );
@@ -844,15 +847,16 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
844
847
#else
845
848
const int tid = threadIdx .x ;
846
849
const uint8_t * q = x[i].qs ;
847
- float * y = yy + i*QK_K;
850
+ dst_t * y = yy + i*QK_K;
848
851
const float d = (float )x[i].dm [0 ];
849
852
const float m = (float )x[i].dm [1 ];
850
853
y[tid+ 0 ] = d * (x[i].scales [0 ] & 0xF ) * (q[tid] & 0xF ) - m * (x[i].scales [0 ] >> 4 );
851
854
y[tid+32 ] = d * (x[i].scales [1 ] & 0xF ) * (q[tid] >> 4 ) - m * (x[i].scales [1 ] >> 4 );
852
855
#endif
853
856
}
854
857
855
- static __global__ void dequantize_block_q5_K (const void * __restrict__ vx, float * __restrict__ yy) {
858
+ template <typename dst_t >
859
+ static __global__ void dequantize_block_q5_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
856
860
const block_q5_K * x = (const block_q5_K *) vx;
857
861
858
862
const int i = blockIdx .x ;
@@ -864,7 +868,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
864
868
const int ir = tid%16 ; // ir is in 0...15
865
869
const int is = 2 *il; // is is in 0...6
866
870
867
- float * y = yy + i*QK_K + 64 *il + 2 *ir;
871
+ dst_t * y = yy + i*QK_K + 64 *il + 2 *ir;
868
872
869
873
const float dall = __low2half (x[i].dm );
870
874
const float dmin = __high2half (x[i].dm );
@@ -892,13 +896,14 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
892
896
const int is = tid/16 ; // 0 or 1
893
897
const uint8_t h = x[i].qh [in] >> im;
894
898
const float d = x[i].d ;
895
- float * y = yy + i*QK_K + tid;
899
+ dst_t * y = yy + i*QK_K + tid;
896
900
y[ 0 ] = d * x[i].scales [is+0 ] * ((q & 0xF ) - ((h >> 0 ) & 1 ? 0 : 16 ));
897
901
y[32 ] = d * x[i].scales [is+2 ] * ((q >> 4 ) - ((h >> 4 ) & 1 ? 0 : 16 ));
898
902
#endif
899
903
}
900
904
901
- static __global__ void dequantize_block_q6_K (const void * __restrict__ vx, float * __restrict__ yy) {
905
+ template <typename dst_t >
906
+ static __global__ void dequantize_block_q6_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
902
907
const block_q6_K * x = (const block_q6_K *) vx;
903
908
904
909
const int i = blockIdx .x ;
@@ -910,7 +915,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
910
915
const int il = tid - 32 *ip; // 0...32
911
916
const int is = 8 *ip + il/16 ;
912
917
913
- float * y = yy + i*QK_K + 128 *ip + il;
918
+ dst_t * y = yy + i*QK_K + 128 *ip + il;
914
919
915
920
const float d = x[i].d ;
916
921
@@ -929,7 +934,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
929
934
const int ip = tid/16 ; // 0 or 1
930
935
const int il = tid - 16 *ip; // 0...15
931
936
932
- float * y = yy + i*QK_K + 16 *ip + il;
937
+ dst_t * y = yy + i*QK_K + 16 *ip + il;
933
938
934
939
const float d = x[i].d ;
935
940
@@ -4604,32 +4609,38 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
4604
4609
quantize_q8_1<<<num_blocks, block_size, 0 , stream>>> (x, vy, kx, kx_padded);
4605
4610
}
4606
4611
4607
- static void dequantize_row_q4_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4612
+ template <typename dst_t >
4613
+ static void dequantize_row_q4_0_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4608
4614
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4609
4615
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
4610
4616
}
4611
4617
4612
- static void dequantize_row_q4_1_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4618
+ template <typename dst_t >
4619
+ static void dequantize_row_q4_1_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4613
4620
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4614
4621
dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
4615
4622
}
4616
4623
4617
- static void dequantize_row_q5_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4624
+ template <typename dst_t >
4625
+ static void dequantize_row_q5_0_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4618
4626
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4619
4627
dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
4620
4628
}
4621
4629
4622
- static void dequantize_row_q5_1_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4630
+ template <typename dst_t >
4631
+ static void dequantize_row_q5_1_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4623
4632
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4624
4633
dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
4625
4634
}
4626
4635
4627
- static void dequantize_row_q8_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4636
+ template <typename dst_t >
4637
+ static void dequantize_row_q8_0_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4628
4638
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
4629
4639
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
4630
4640
}
4631
4641
4632
- static void dequantize_row_q2_K_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4642
+ template <typename dst_t >
4643
+ static void dequantize_row_q2_K_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4633
4644
const int nb = k / QK_K;
4634
4645
#if QK_K == 256
4635
4646
dequantize_block_q2_K<<<nb, 64 , 0 , stream>>> (vx, y);
@@ -4638,7 +4649,8 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
4638
4649
#endif
4639
4650
}
4640
4651
4641
- static void dequantize_row_q3_K_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4652
+ template <typename dst_t >
4653
+ static void dequantize_row_q3_K_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4642
4654
const int nb = k / QK_K;
4643
4655
#if QK_K == 256
4644
4656
dequantize_block_q3_K<<<nb, 64 , 0 , stream>>> (vx, y);
@@ -4647,21 +4659,23 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
4647
4659
#endif
4648
4660
}
4649
4661
4650
- static void dequantize_row_q4_K_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4662
+ template <typename dst_t >
4663
+ static void dequantize_row_q4_K_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4651
4664
const int nb = k / QK_K;
4652
4665
dequantize_block_q4_K<<<nb, 32 , 0 , stream>>> (vx, y);
4653
4666
}
4654
4667
4655
- static void dequantize_row_q5_K_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4668
+ template <typename dst_t >
4669
+ static void dequantize_row_q5_K_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4656
4670
const int nb = k / QK_K;
4657
4671
#if QK_K == 256
4658
4672
dequantize_block_q5_K<<<nb, 64 , 0 , stream>>> (vx, y);
4659
4673
#else
4660
4674
dequantize_block_q5_K<<<nb, 32 , 0 , stream>>> (vx, y);
4661
4675
#endif
4662
4676
}
4663
-
4664
- static void dequantize_row_q6_K_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
4677
+ template < typename dst_t >
4678
+ static void dequantize_row_q6_K_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
4665
4679
const int nb = k / QK_K;
4666
4680
#if QK_K == 256
4667
4681
dequantize_block_q6_K<<<nb, 64 , 0 , stream>>> (vx, y);
@@ -4868,6 +4882,26 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
4868
4882
4869
4883
static to_fp16_cuda_t ggml_get_to_fp16_cuda (ggml_type type) {
4870
4884
switch (type) {
4885
+ case GGML_TYPE_Q4_0:
4886
+ return dequantize_row_q4_0_cuda;
4887
+ case GGML_TYPE_Q4_1:
4888
+ return dequantize_row_q4_1_cuda;
4889
+ case GGML_TYPE_Q5_0:
4890
+ return dequantize_row_q5_0_cuda;
4891
+ case GGML_TYPE_Q5_1:
4892
+ return dequantize_row_q5_1_cuda;
4893
+ case GGML_TYPE_Q8_0:
4894
+ return dequantize_row_q8_0_cuda;
4895
+ case GGML_TYPE_Q2_K:
4896
+ return dequantize_row_q2_K_cuda;
4897
+ case GGML_TYPE_Q3_K:
4898
+ return dequantize_row_q3_K_cuda;
4899
+ case GGML_TYPE_Q4_K:
4900
+ return dequantize_row_q4_K_cuda;
4901
+ case GGML_TYPE_Q5_K:
4902
+ return dequantize_row_q5_K_cuda;
4903
+ case GGML_TYPE_Q6_K:
4904
+ return dequantize_row_q6_K_cuda;
4871
4905
case GGML_TYPE_F32:
4872
4906
return convert_fp32_to_fp16_cuda;
4873
4907
default :
@@ -6083,8 +6117,19 @@ inline void ggml_cuda_op_mul_mat_cublas(
6083
6117
6084
6118
const int compute_capability = g_compute_capabilities[id];
6085
6119
6086
- if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous (src0) && ldc == row_diff) {
6087
- // convert src1 to fp16, multiply as fp16, convert dst to fp32
6120
+ if (compute_capability >= CC_TURING && (src0->type == GGML_TYPE_F16 || ggml_is_quantized (src0->type )) && ggml_is_contiguous (src0) && ldc == row_diff) {
6121
+ // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
6122
+ half * src0_as_f16 = nullptr ;
6123
+ size_t src0_as = 0 ;
6124
+ if (src0->type != GGML_TYPE_F16) {
6125
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda (src0->type );
6126
+ GGML_ASSERT (to_fp16_cuda != nullptr );
6127
+ size_t ne = row_diff*ne00;
6128
+ src0_as_f16 = (half *) ggml_cuda_pool_malloc (ne * sizeof (half), &src0_as);
6129
+ to_fp16_cuda (src0_dd_i, src0_as_f16, ne, stream);
6130
+ }
6131
+ const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (half *) src0_dd_i : src0_as_f16;
6132
+
6088
6133
half * src1_as_f16 = nullptr ;
6089
6134
size_t src1_as = 0 ;
6090
6135
if (src1->type != GGML_TYPE_F16) {
@@ -6106,9 +6151,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
6106
6151
CUBLAS_CHECK (
6107
6152
cublasGemmEx (g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
6108
6153
row_diff, src1_ncols, ne10,
6109
- &alpha_f16, src0_dd_i , CUDA_R_16F, ne00,
6110
- src1_ptr, CUDA_R_16F, ne10,
6111
- &beta_f16, dst_f16, CUDA_R_16F, ldc,
6154
+ &alpha_f16, src0_ptr , CUDA_R_16F, ne00,
6155
+ src1_ptr, CUDA_R_16F, ne10,
6156
+ &beta_f16, dst_f16, CUDA_R_16F, ldc,
6112
6157
CUBLAS_COMPUTE_16F,
6113
6158
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
6114
6159
@@ -6117,6 +6162,10 @@ inline void ggml_cuda_op_mul_mat_cublas(
6117
6162
6118
6163
ggml_cuda_pool_free (dst_f16, dst_as);
6119
6164
6165
+ if (src0_as != 0 ) {
6166
+ ggml_cuda_pool_free (src0_as_f16, src0_as);
6167
+ }
6168
+
6120
6169
if (src1_as != 0 ) {
6121
6170
ggml_cuda_pool_free (src1_as_f16, src1_as);
6122
6171
}
0 commit comments