Skip to content

Commit cf94b82

Browse files
committed
kompute: add mul_mat_q4_k shader
This is a more or less direct translation from the Metal implementation to GLSL. Signed-off-by: Sergio Lopez <[email protected]>
1 parent aa303fe commit cf94b82

File tree

4 files changed

+186
-0
lines changed

4 files changed

+186
-0
lines changed

ggml/src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
800800
kompute-shaders/op_mul_mat_q8_0.comp
801801
kompute-shaders/op_mul_mat_q4_0.comp
802802
kompute-shaders/op_mul_mat_q4_1.comp
803+
kompute-shaders/op_mul_mat_q4_k.comp
803804
kompute-shaders/op_mul_mat_q6_k.comp
804805
kompute-shaders/op_getrows_f32.comp
805806
kompute-shaders/op_getrows_f16.comp
@@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
833834
shaderop_mul_mat_q8_0.h
834835
shaderop_mul_mat_q4_0.h
835836
shaderop_mul_mat_q4_1.h
837+
shaderop_mul_mat_q4_k.h
836838
shaderop_mul_mat_q6_k.h
837839
shaderop_getrows_f32.h
838840
shaderop_getrows_f16.h

ggml/src/ggml-kompute.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "shaderop_mul_mat_q8_0.h"
2121
#include "shaderop_mul_mat_q4_0.h"
2222
#include "shaderop_mul_mat_q4_1.h"
23+
#include "shaderop_mul_mat_q4_k.h"
2324
#include "shaderop_mul_mat_q6_k.h"
2425
#include "shaderop_mul_mat_mat_f32.h"
2526
#include "shaderop_getrows_f32.h"
@@ -1076,6 +1077,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
10761077
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
10771078
}
10781079

1080+
static void ggml_vk_mul_mat_q4_k(
1081+
kp::Sequence& seq,
1082+
const std::shared_ptr<kp::Tensor>& inA,
1083+
const std::shared_ptr<kp::Tensor>& inB,
1084+
const std::shared_ptr<kp::Tensor>& out,
1085+
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1086+
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
1087+
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
1088+
int32_t ne1, int32_t r2, int32_t r3
1089+
) {
1090+
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
1091+
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
1092+
1093+
struct PushConstants {
1094+
uint32_t inAOff, inBOff, outOff;
1095+
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
1096+
} pushConsts {
1097+
0, 0, 0,
1098+
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
1099+
};
1100+
1101+
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1102+
if (!komputeManager()->hasAlgorithm(__func__)) {
1103+
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
1104+
} else {
1105+
s_algo = komputeManager()->getAlgorithm(__func__);
1106+
s_algo->setTensors({inA, inB, out});
1107+
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
1108+
s_algo->setPushConstants<PushConstants>({pushConsts});
1109+
s_algo->updateDescriptors(s_kompute_context->pool.get());
1110+
}
1111+
seq.record<kp::OpAlgoDispatch>(s_algo);
1112+
}
1113+
10791114
static void ggml_vk_mul_mat_q6_k(
10801115
kp::Sequence& seq,
10811116
const std::shared_ptr<kp::Tensor>& inA,
@@ -1393,6 +1428,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
13931428
case GGML_TYPE_Q8_0:
13941429
case GGML_TYPE_Q4_0:
13951430
case GGML_TYPE_Q4_1:
1431+
case GGML_TYPE_Q4_K:
13961432
return true;
13971433
default:
13981434
;
@@ -1651,6 +1687,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
16511687
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
16521688
);
16531689
break;
1690+
case GGML_TYPE_Q4_K:
1691+
ggml_vk_mul_mat_q4_k(
1692+
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1693+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
1694+
);
1695+
break;
16541696
case GGML_TYPE_Q6_K:
16551697
ggml_vk_mul_mat_q6_k(
16561698
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,

ggml/src/kompute-shaders/common.comp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define TWOPI_F 6.283185307179586f
1616

1717
#define QK_K 256
18+
#define K_SCALE_SIZE 12
1819

1920
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
2021
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
@@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
6465
return reg;
6566
}
6667

68+
#define sizeof_block_q4_k 144
69+
struct block_q4_k {
70+
float16_t d;
71+
float16_t dmin;
72+
uint8_t scales[K_SCALE_SIZE];
73+
uint8_t qs[QK_K/2];
74+
};
75+
6776
#define sizeof_block_q6_k 210
6877
struct block_q6_k {
6978
uint8_t ql[QK_K/2]; // quants, lower 4 bits
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#version 450
2+
3+
#include "common.comp"
4+
5+
#define N_DST 4
6+
#define SIZE_OF_BLOCK sizeof_block_q4_k
7+
8+
layout(local_size_x = 4) in;
9+
layout(local_size_y = 8) in;
10+
layout(local_size_z = 1) in;
11+
12+
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
13+
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
14+
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
15+
16+
layout (push_constant) uniform parameter {
17+
uint inAOff;
18+
uint inBOff;
19+
uint outOff;
20+
int ne00;
21+
int ne10;
22+
int ne0;
23+
int ne1;
24+
int ne01;
25+
int ne02;
26+
int ne12;
27+
int r2;
28+
int r3;
29+
} pcs;
30+
31+
void main() {
32+
const uint16_t kmask1 = uint16_t(0x3f3f);
33+
const uint16_t kmask2 = uint16_t(0x0f0f);
34+
const uint16_t kmask3 = uint16_t(0xc0c0);
35+
36+
const uint ix = gl_SubgroupInvocationID/8; // 0...3
37+
const uint it = gl_SubgroupInvocationID%8; // 0...7
38+
const uint iq = it/4; // 0 or 1
39+
const uint ir = it%4; // 0...3
40+
41+
const uint nb = pcs.ne00/QK_K;
42+
43+
const uint r0 = gl_WorkGroupID.x;
44+
const uint r1 = gl_WorkGroupID.y;
45+
const uint im = gl_WorkGroupID.z;
46+
47+
const uint first_row = r0 * N_DST;
48+
const uint ib_row = first_row * nb;
49+
50+
const uint i12 = im%pcs.ne12;
51+
const uint i13 = im/pcs.ne12;
52+
53+
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
54+
55+
const uint xblk = ib_row + offset0 + pcs.inAOff;
56+
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
57+
58+
float yl[16];
59+
float yh[16];
60+
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
61+
float all_sum = 0.f;
62+
63+
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
64+
65+
for (uint ib = ix; ib < nb; ib += 4) {
66+
const uint blk_idx = ib + xblk;
67+
68+
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
69+
for (int i = 0; i < 8; ++i) {
70+
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
71+
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
72+
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
73+
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
74+
}
75+
76+
for (int row = 0; row < N_DST; row++) {
77+
uint row_idx = row * nb;
78+
79+
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
80+
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
81+
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
82+
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
83+
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
84+
85+
uint16_t sc16[4];
86+
sc16[0] = sc_0 & kmask1;
87+
sc16[1] = sc_2 & kmask1;
88+
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
89+
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
90+
91+
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
92+
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
93+
for (int i = 0; i < 8; i += 2) {
94+
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
95+
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
96+
acc1[0] += yl[i+0] * (q1 & 0x000F);
97+
acc1[1] += yl[i+1] * (q1 & 0x0F00);
98+
acc1[2] += yl[i+8] * (q1 & 0x00F0);
99+
acc1[3] += yl[i+9] * (q1 & 0xF000);
100+
acc2[0] += yh[i+0] * (q2 & 0x000F);
101+
acc2[1] += yh[i+1] * (q2 & 0x0F00);
102+
acc2[2] += yh[i+8] * (q2 & 0x00F0);
103+
acc2[3] += yh[i+9] * (q2 & 0xF000);
104+
}
105+
106+
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
107+
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
108+
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
109+
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
110+
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
111+
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
112+
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
113+
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
114+
115+
float dall = float(inA[blk_idx + row_idx].d);
116+
float dmin = float(inA[blk_idx + row_idx].dmin);
117+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
118+
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
119+
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
120+
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
121+
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
122+
}
123+
124+
y4 += 4 * QK_K;
125+
}
126+
127+
for (int row = 0; row < N_DST; ++row) {
128+
all_sum = subgroupAdd(sumf[row]);
129+
if (subgroupElect()) {
130+
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
131+
}
132+
}
133+
}

0 commit comments

Comments
 (0)