Skip to content

Commit 62832c5

Browse files
committed
ggml-cuda : perform cublas matrix multiplication of quantized types as fp16
1 parent 40e07a6 commit 62832c5

File tree

1 file changed

+80
-31
lines changed

1 file changed

+80
-31
lines changed

ggml-cuda.cu

Lines changed: 80 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -715,7 +715,8 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
715715

716716
//================================== k-quants
717717

718-
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
718+
template<typename dst_t>
719+
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
719720

720721
const int i = blockIdx.x;
721722
const block_q2_K * x = (const block_q2_K *) vx;
@@ -727,7 +728,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
727728
const int is = 8*n + l/16;
728729

729730
const uint8_t q = x[i].qs[32*n + l];
730-
float * y = yy + i*QK_K + 128*n;
731+
dst_t * y = yy + i*QK_K + 128*n;
731732

732733
float dall = __low2half(x[i].dm);
733734
float dmin = __high2half(x[i].dm);
@@ -739,7 +740,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
739740
const int is = tid/16; // 0 or 1
740741
const int il = tid%16; // 0...15
741742
const uint8_t q = x[i].qs[il] >> (2*is);
742-
float * y = yy + i*QK_K + 16*is + il;
743+
dst_t * y = yy + i*QK_K + 16*is + il;
743744
float dall = __low2half(x[i].dm);
744745
float dmin = __high2half(x[i].dm);
745746
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
@@ -748,7 +749,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
748749

749750
}
750751

751-
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
752+
template<typename dst_t>
753+
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
752754

753755
const int i = blockIdx.x;
754756
const block_q3_K * x = (const block_q3_K *) vx;
@@ -772,7 +774,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
772774
float d_all = x[i].d;
773775
float dl = d_all * (us - 32);
774776

775-
float * y = yy + i*QK_K + 128*n + 32*j;
777+
dst_t * y = yy + i*QK_K + 128*n + 32*j;
776778
const uint8_t * q = x[i].qs + 32*n;
777779
const uint8_t * hm = x[i].hmask;
778780

@@ -784,7 +786,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
784786
const int im = il/8; // 0...1
785787
const int in = il%8; // 0...7
786788

787-
float * y = yy + i*QK_K + 16*is + il;
789+
dst_t * y = yy + i*QK_K + 16*is + il;
788790

789791
const uint8_t q = x[i].qs[il] >> (2*is);
790792
const uint8_t h = x[i].hmask[in] >> (2*is + im);
@@ -812,7 +814,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
812814
}
813815
#endif
814816

815-
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
817+
template<typename dst_t>
818+
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
816819
const block_q4_K * x = (const block_q4_K *) vx;
817820

818821
const int i = blockIdx.x;
@@ -825,7 +828,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
825828
const int is = 2*il;
826829
const int n = 4;
827830

828-
float * y = yy + i*QK_K + 64*il + n*ir;
831+
dst_t * y = yy + i*QK_K + 64*il + n*ir;
829832

830833
const float dall = __low2half(x[i].dm);
831834
const float dmin = __high2half(x[i].dm);
@@ -844,15 +847,16 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
844847
#else
845848
const int tid = threadIdx.x;
846849
const uint8_t * q = x[i].qs;
847-
float * y = yy + i*QK_K;
850+
dst_t * y = yy + i*QK_K;
848851
const float d = (float)x[i].dm[0];
849852
const float m = (float)x[i].dm[1];
850853
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
851854
y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
852855
#endif
853856
}
854857

855-
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
858+
template<typename dst_t>
859+
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
856860
const block_q5_K * x = (const block_q5_K *) vx;
857861

858862
const int i = blockIdx.x;
@@ -864,7 +868,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
864868
const int ir = tid%16; // ir is in 0...15
865869
const int is = 2*il; // is is in 0...6
866870

867-
float * y = yy + i*QK_K + 64*il + 2*ir;
871+
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
868872

869873
const float dall = __low2half(x[i].dm);
870874
const float dmin = __high2half(x[i].dm);
@@ -892,13 +896,14 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
892896
const int is = tid/16; // 0 or 1
893897
const uint8_t h = x[i].qh[in] >> im;
894898
const float d = x[i].d;
895-
float * y = yy + i*QK_K + tid;
899+
dst_t * y = yy + i*QK_K + tid;
896900
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
897901
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
898902
#endif
899903
}
900904

901-
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
905+
template<typename dst_t>
906+
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
902907
const block_q6_K * x = (const block_q6_K *) vx;
903908

904909
const int i = blockIdx.x;
@@ -910,7 +915,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
910915
const int il = tid - 32*ip; // 0...32
911916
const int is = 8*ip + il/16;
912917

913-
float * y = yy + i*QK_K + 128*ip + il;
918+
dst_t * y = yy + i*QK_K + 128*ip + il;
914919

915920
const float d = x[i].d;
916921

@@ -929,7 +934,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
929934
const int ip = tid/16; // 0 or 1
930935
const int il = tid - 16*ip; // 0...15
931936

932-
float * y = yy + i*QK_K + 16*ip + il;
937+
dst_t * y = yy + i*QK_K + 16*ip + il;
933938

934939
const float d = x[i].d;
935940

@@ -4604,32 +4609,38 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
46044609
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
46054610
}
46064611

4607-
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4612+
template<typename dst_t>
4613+
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46084614
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
46094615
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
46104616
}
46114617

4612-
static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4618+
template<typename dst_t>
4619+
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46134620
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
46144621
dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
46154622
}
46164623

4617-
static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4624+
template<typename dst_t>
4625+
static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46184626
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
46194627
dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
46204628
}
46214629

4622-
static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4630+
template<typename dst_t>
4631+
static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46234632
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
46244633
dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
46254634
}
46264635

4627-
static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4636+
template<typename dst_t>
4637+
static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46284638
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
46294639
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
46304640
}
46314641

4632-
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4642+
template<typename dst_t>
4643+
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46334644
const int nb = k / QK_K;
46344645
#if QK_K == 256
46354646
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4638,7 +4649,8 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
46384649
#endif
46394650
}
46404651

4641-
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4652+
template<typename dst_t>
4653+
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46424654
const int nb = k / QK_K;
46434655
#if QK_K == 256
46444656
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4647,21 +4659,23 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
46474659
#endif
46484660
}
46494661

4650-
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4662+
template<typename dst_t>
4663+
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46514664
const int nb = k / QK_K;
46524665
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
46534666
}
46544667

4655-
static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4668+
template<typename dst_t>
4669+
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46564670
const int nb = k / QK_K;
46574671
#if QK_K == 256
46584672
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
46594673
#else
46604674
dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
46614675
#endif
46624676
}
4663-
4664-
static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
4677+
template<typename dst_t>
4678+
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
46654679
const int nb = k / QK_K;
46664680
#if QK_K == 256
46674681
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4868,6 +4882,26 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
48684882

48694883
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
48704884
switch (type) {
4885+
case GGML_TYPE_Q4_0:
4886+
return dequantize_row_q4_0_cuda;
4887+
case GGML_TYPE_Q4_1:
4888+
return dequantize_row_q4_1_cuda;
4889+
case GGML_TYPE_Q5_0:
4890+
return dequantize_row_q5_0_cuda;
4891+
case GGML_TYPE_Q5_1:
4892+
return dequantize_row_q5_1_cuda;
4893+
case GGML_TYPE_Q8_0:
4894+
return dequantize_row_q8_0_cuda;
4895+
case GGML_TYPE_Q2_K:
4896+
return dequantize_row_q2_K_cuda;
4897+
case GGML_TYPE_Q3_K:
4898+
return dequantize_row_q3_K_cuda;
4899+
case GGML_TYPE_Q4_K:
4900+
return dequantize_row_q4_K_cuda;
4901+
case GGML_TYPE_Q5_K:
4902+
return dequantize_row_q5_K_cuda;
4903+
case GGML_TYPE_Q6_K:
4904+
return dequantize_row_q6_K_cuda;
48714905
case GGML_TYPE_F32:
48724906
return convert_fp32_to_fp16_cuda;
48734907
default:
@@ -6083,8 +6117,19 @@ inline void ggml_cuda_op_mul_mat_cublas(
60836117

60846118
const int compute_capability = g_compute_capabilities[id];
60856119

6086-
if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) {
6087-
// convert src1 to fp16, multiply as fp16, convert dst to fp32
6120+
if (compute_capability >= CC_TURING && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && ldc == row_diff) {
6121+
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
6122+
half * src0_as_f16 = nullptr;
6123+
size_t src0_as = 0;
6124+
if (src0->type != GGML_TYPE_F16) {
6125+
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
6126+
GGML_ASSERT(to_fp16_cuda != nullptr);
6127+
size_t ne = row_diff*ne00;
6128+
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
6129+
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
6130+
}
6131+
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (half *) src0_dd_i : src0_as_f16;
6132+
60886133
half * src1_as_f16 = nullptr;
60896134
size_t src1_as = 0;
60906135
if (src1->type != GGML_TYPE_F16) {
@@ -6106,9 +6151,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
61066151
CUBLAS_CHECK(
61076152
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
61086153
row_diff, src1_ncols, ne10,
6109-
&alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
6110-
src1_ptr, CUDA_R_16F, ne10,
6111-
&beta_f16, dst_f16, CUDA_R_16F, ldc,
6154+
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
6155+
src1_ptr, CUDA_R_16F, ne10,
6156+
&beta_f16, dst_f16, CUDA_R_16F, ldc,
61126157
CUBLAS_COMPUTE_16F,
61136158
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
61146159

@@ -6117,6 +6162,10 @@ inline void ggml_cuda_op_mul_mat_cublas(
61176162

61186163
ggml_cuda_pool_free(dst_f16, dst_as);
61196164

6165+
if (src0_as != 0) {
6166+
ggml_cuda_pool_free(src0_as_f16, src0_as);
6167+
}
6168+
61206169
if (src1_as != 0) {
61216170
ggml_cuda_pool_free(src1_as_f16, src1_as);
61226171
}

0 commit comments

Comments
 (0)