
Commit f0f71a5

CUDA: MMQ code deduplication + iquant support

1 parent 87e397d

File tree

10 files changed: +808 / -647 lines

ggml/src/ggml-cuda/mmq.cu

Lines changed: 24 additions & 0 deletions
@@ -59,6 +59,24 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
             mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
             break;
@@ -93,6 +111,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            mmq_supported = true;
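
Both hunks follow the same shape: ggml_cuda_op_mul_mat_q forwards each supported quantization type to the templated mul_mat_q_case<type>, and ggml_cuda_should_use_mmq whitelists the same set of types, so adding an iquant type means touching both switches. As a rough, self-contained sketch of that switch-to-template pattern (hypothetical names and signatures, not the actual ggml code):

// Minimal sketch of the switch-to-template dispatch pattern; demo_type,
// launch_case and demo_should_use_mmq are hypothetical stand-ins for
// ggml_type, mul_mat_q_case<type> and ggml_cuda_should_use_mmq.
#include <cstdio>

enum demo_type { DEMO_TYPE_Q6_K, DEMO_TYPE_IQ2_XXS, DEMO_TYPE_F32 };

template <demo_type T>
static void launch_case() {
    // The real launcher configures and runs the MMQ kernel specialized for T.
    printf("MMQ kernel for type %d\n", (int) T);
}

static void demo_op_mul_mat_q(demo_type t) {
    switch (t) {
        case DEMO_TYPE_Q6_K:    launch_case<DEMO_TYPE_Q6_K>();    break;
        case DEMO_TYPE_IQ2_XXS: launch_case<DEMO_TYPE_IQ2_XXS>(); break;
        default:                break; // unsupported types are filtered out earlier
    }
}

static bool demo_should_use_mmq(demo_type t) {
    switch (t) {
        case DEMO_TYPE_Q6_K:
        case DEMO_TYPE_IQ2_XXS:
            return true;  // a new quant type needs a case here and a launch case above
        default:
            return false;
    }
}

int main() {
    if (demo_should_use_mmq(DEMO_TYPE_IQ2_XXS)) {
        demo_op_mul_mat_q(DEMO_TYPE_IQ2_XXS);  // prints: MMQ kernel for type 1
    }
    return 0;
}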

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 731 additions & 646 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
     "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
-    "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
+    "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
+    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
 ]
 
 SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
New template instance file (5 additions & 0 deletions):

// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ1_S);

New template instance file (5 additions & 0 deletions):

// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ2_S);

New template instance file (5 additions & 0 deletions):

// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);

New template instance file (5 additions & 0 deletions):

// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);

New template instance file (5 additions & 0 deletions):

// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ3_S);

New template instance file (5 additions & 0 deletions):

// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
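
Each of the six new files is deliberately tiny: it includes mmq.cuh and expands DECL_MMQ_CASE for exactly one quantization type, so each template instantiation lives in its own translation unit and the instances can be compiled in parallel. The macro itself is defined in mmq.cuh and its expansion is not part of this diff; purely as a hedged illustration, such a macro can be as simple as an explicit template instantiation (the snippet below compiles as a standalone translation unit):

// Illustration only: one plausible shape for a per-type instantiation macro.
// mul_mat_q_case_demo and DECL_MMQ_CASE_DEMO are hypothetical; the real
// DECL_MMQ_CASE and mul_mat_q_case<type> live in ggml/src/ggml-cuda/mmq.cuh.

// Shared template, normally defined once in a common header:
template <int type_id>
void mul_mat_q_case_demo() {
    // kernel-launch logic shared by every quantization type
}

// Expanding this in a dedicated .cu file emits exactly one explicit
// instantiation of the shared template for the given type:
#define DECL_MMQ_CASE_DEMO(type_id) \
    template void mul_mat_q_case_demo<type_id>()

// Rough equivalent of one generated instance file:
DECL_MMQ_CASE_DEMO(3);  // placeholder id, not a real GGML_TYPE value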

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 21 additions & 0 deletions
@@ -188,6 +188,27 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 }
 
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
+    const int * v, const int * u, const float * d8_0, const float & d8_1) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
+        int sumi = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_0/2; ++i) {
+            // SIMD dot product of quantized values
+            sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+        }
+
+        sumf += d8_0[i0/(QI8_0/2)]*sumi;
+    }
+
+    return d8_1*sumf;
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ 4
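
The new vec_dot_q8_0_16_q8_1_impl helper walks the packed int32 inputs in groups of QI8_0/2 values, accumulates each group's integer dot product with ggml_cuda_dp4a, scales it by that group's own q8_0 factor d8_0[...], and only then applies the single q8_1 factor d8_1. A host-side scalar reference of the same arithmetic, written as an illustration under the assumption that a group of QI8_0/2 packed int32 values corresponds to 16 int8 quants sharing one scale:

// Scalar reference for the math in vec_dot_q8_0_16_q8_1_impl (illustration
// only): unpacked int8 values instead of dp4a over packed int32, assuming
// each group of 16 quants shares one d8_0 scale.
#include <cstdint>
#include <cstdio>

// n is the number of int8 values covered (4x the vdr of the device helper).
static float vec_dot_q8_0_16_q8_1_ref(const int8_t * v, const int8_t * u,
                                      const float * d8_0, float d8_1, int n) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < n; i0 += 16) {      // one q8_0 scale per 16 quants
        int sumi = 0;
        for (int i = i0; i < i0 + 16; ++i) {
            sumi += (int) v[i] * (int) u[i];  // integer dot product of the group
        }
        sumf += d8_0[i0/16] * sumi;           // per-group q8_0 scale
    }
    return d8_1 * sumf;                       // single q8_1 scale at the end
}

int main() {
    int8_t v[16], u[16];
    for (int i = 0; i < 16; ++i) { v[i] = (int8_t) (i - 8); u[i] = 2; }
    const float d8_0[1] = {0.5f};
    printf("%f\n", vec_dot_q8_0_16_q8_1_ref(v, u, d8_0, 0.25f, 16));  // -2.000000
    return 0;
}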
