
Commit 8e55830

CUDA: MMQ support for iq4_nl, iq4_xs (#8278)
1 parent 0a42380 commit 8e55830
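
For context: unlike the linear q4_0/q4_1 formats, iq4_nl and iq4_xs store 4-bit indices into a non-uniform codebook, so the MMQ (mul_mat_q) kernels gaining support here need a table lookup on top of the integer dot products. Below is a minimal sketch of that per-value mapping; the codebook values are quoted from memory as ggml-common.h's kvalues_iq4nl and should be verified against the source, and dequantize_iq4nl_one is a hypothetical helper name used only for illustration.

// Sketch: iq4_nl dequantization via a non-linear codebook.
// The values are believed to match kvalues_iq4nl in ggml-common.h,
// but verify against the repository before relying on them.
static const __device__ int8_t iq4nl_values_sketch[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10,
       1,   13,  25,  38,  53,  69,  89, 113,
};

// One value: a 4-bit index q (0..15) scaled by the block scale d.
// Contrast with q4_0, where the mapping is linear: d * (q - 8).
static __device__ __forceinline__ float dequantize_iq4nl_one(const float d, const int q) {
    return d * iq4nl_values_sketch[q & 0x0F];
}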

File tree

7 files changed: +226 -80 lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 9 additions & 9 deletions
@@ -68,7 +68,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const int iqs4 = k_KQ % QI4_0;
     const int shift = k_KQ & (QI8_1/2);

-    const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
     const int u = Q_q8[k_KQ_0/WARP_SIZE];

     const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -108,7 +108,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const int iqs4 = k_KQ % QI4_1;
     const int shift = k_KQ & (QI8_1/2);

-    const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
     const int u = Q_q8[k_KQ_0/WARP_SIZE];

     const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -153,8 +153,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const int iqs8 = k_KQ % QI8_1;
     const int shift = k_KQ & (QI8_1/2);

-    int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-    const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+    int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
     v |= (vh << 4) & 0x00000010; // 0 -> 4
     v |= (vh << 11) & 0x00001000; // 1 -> 12
     v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -200,8 +200,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const int iqs8 = k_KQ % QI8_1;
     const int shift = k_KQ & (QI8_1/2);

-    int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-    const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+    int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
     v |= (vh << 4) & 0x00000010; // 0 -> 4
     v |= (vh << 11) & 0x00001000; // 1 -> 12
     v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -249,7 +249,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const int ib = k_KQ / QI8_0;
     const int iqs = k_KQ % QI8_0;

-    const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
+    const int v = get_int_b2(K_q8_0[ib].qs, iqs);

     T Q_d;
     if (std::is_same<T, half>::value) {
@@ -408,7 +408,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__

     const T d = x[ib].d;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_from_uint8(x[ib].qh, 0);
+    const int qh0 = get_int_b2(x[ib].qh, 0);
     const int ql = ((ql0 >> (4*shift)) & 0x0F);
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh) - 16;
@@ -433,7 +433,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__

     const half2 dm = x[ib].dm;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
+    const int qh0 = get_int_b4(x[ib].qh, 0);
     const int ql = ((ql0 >> (4*shift)) & 0x0F);
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh);
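
The get_int_from_* helpers are renamed to get_int_b2/get_int_b4 throughout this file. The suffix apparently encodes the minimum byte alignment of the source pointer rather than its element type, which is what actually matters on the GPU: a 32-bit load from an address that is only 2-byte aligned is not allowed in CUDA, so such data has to be assembled from two 16-bit loads. A plausible sketch of the two helpers, under that naming assumption (not taken verbatim from the commit):

// get_int_b2: source guarantees at least 2-byte alignment, so read two
// 16-bit halves and combine them into one 32-bit word.
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) x;

    int x32 = x16[2*i32 + 0] <<  0;
    x32    |= x16[2*i32 + 1] << 16;

    return x32;
}

// get_int_b4: source guarantees at least 4-byte alignment, so a direct
// 32-bit load is safe.
static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
    return ((const int *) x)[i32];
}

This matches the substitution pattern above: call sites that previously used get_int_from_uint8_aligned (e.g. on q4_1, whose block carries a half2 and is thus 4-byte aligned) move to get_int_b4, while the unaligned variants move to get_int_b2.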

ggml/src/ggml-cuda/mmq.cu

Lines changed: 8 additions & 0 deletions
@@ -59,6 +59,12 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
             mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -87,6 +93,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             mmq_supported = true;
             break;
         default:
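
ggml_cuda_should_use_mmq is the predicate the backend uses to decide between the quantized MMQ kernels and the dequantize-plus-cuBLAS path, so a type only benefits from the new kernels once it is listed in this switch. A rough usage sketch of the call-site shape (only ggml_cuda_should_use_mmq itself is from the source; the wrapper name use_mmq_for is hypothetical):

// Hypothetical call-site shape. ne11 is src1's second dimension, i.e. the
// number of columns in the multiplication; cc is the device's compute
// capability. Small batches tend to favor the quantized MMQ kernels.
static bool use_mmq_for(const ggml_tensor * src0, const ggml_tensor * src1, const int cc) {
    return ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
}

The per-type mul_mat_q_case<...> instantiations are compiled in separate translation units under ggml/src/ggml-cuda/template-instances/, which is presumably where most of the remaining files touched by this commit live.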
