Skip to content

Commit c03579d

Browse files
vulkan: initial support for IQ3_XXS
1 parent 3d2c7a0 commit c03579d

File tree

10 files changed

+255
-6
lines changed

10 files changed

+255
-6
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 45 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,54 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
8888
}
8989
#endif
9090

91+
#if defined(DATA_A_IQ3_XXS)
92+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
93+
const uint ib4 = iqs / 4;
94+
const uint ib32 = iqs / 32;
95+
const uint is = QUANT_K / 4 + 4 * ib32;
96+
const uint qs = data_a[a_offset + ib].qs[ib4];
97+
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
98+
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
99+
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
100+
const float db = 0.5 * (0.5 + (signs >> 28));
101+
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
102+
// Add parity bit
103+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
104+
const uint sign = sign8 >> (iqs % 8);
105+
const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
106+
bool sign0 = (sign & 1) != 0;
107+
bool sign1 = (sign & 2) != 0;
108+
return db * vec2(
109+
grid.x * (sign0 ? -1.0 : 1.0),
110+
grid.y * (sign1 ? -1.0 : 1.0)
111+
);
112+
}
113+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
114+
const uint ib4 = iqs / 4;
115+
const uint ib32 = iqs / 32;
116+
const uint is = QUANT_K / 4 + 4 * ib32;
117+
const uint qs = data_a[a_offset + ib].qs[ib4];
118+
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
119+
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
120+
const float db = 0.5 * (0.5 + (signs >> 28));
121+
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
122+
// Add parity bit
123+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
124+
const uint sign = sign8 >> (iqs % 8);
125+
const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
126+
bool sign0 = (sign & 1) != 0;
127+
bool sign1 = (sign & 2) != 0;
128+
bool sign2 = (sign & 4) != 0;
129+
bool sign3 = (sign & 8) != 0;
130+
return db * vec4(
131+
grid.x * (sign0 ? -1.0 : 1.0),
132+
grid.y * (sign1 ? -1.0 : 1.0),
133+
grid.z * (sign2 ? -1.0 : 1.0),
134+
grid.w * (sign3 ? -1.0 : 1.0)
135+
);
136+
}
137+
#endif
138+
91139
#if defined(DATA_A_IQ3_S)
92140
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
93141
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
@@ -142,7 +190,7 @@ vec2 get_dm(uint ib, uint a_offset) {
142190
}
143191
#endif
144192

145-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
193+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
146194
vec2 get_dm(uint ib, uint a_offset) {
147195
return vec2(float(data_a[a_offset + ib].d), 0);
148196
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#version 450
2+
3+
#include "dequant_head.comp"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
11+
// Each thread handles 1 scale block (32 values)
12+
// 8 threads handle 1 superblock
13+
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
14+
15+
init_iq3xxs_shmem();
16+
17+
if (ib >= p.nel / 256) {
18+
return;
19+
}
20+
21+
const uint is = gl_LocalInvocationID.x % 8;
22+
const uint b_idx = 256 * ib + 32 * is;
23+
const uint s_idx = QUANT_K / 4 + 4 * is;
24+
25+
const float d = float(data_a[ib].d);
26+
uint signscale = pack32(u8vec4(
27+
data_a[ib].qs[s_idx + 0],
28+
data_a[ib].qs[s_idx + 1],
29+
data_a[ib].qs[s_idx + 2],
30+
data_a[ib].qs[s_idx + 3]
31+
));
32+
const float db = d * (0.5 + (signscale >> 28)) * 0.5;
33+
34+
[[unroll]] for (uint l = 0; l < 4; ++l) {
35+
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
36+
// Restore parity bit.
37+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
38+
const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]);
39+
const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]);
40+
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
41+
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
42+
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
43+
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
44+
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
45+
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
46+
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
47+
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
48+
}
49+
}

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15-
#if defined(DATA_A_IQ3_S)
15+
#if defined(DATA_A_IQ3_XXS)
16+
init_iq3xxs_shmem();
17+
#elif defined(DATA_A_IQ3_S)
1618
init_iq3s_shmem();
1719
#elif defined(DATA_A_IQ4_NL)
1820
init_iq4nl_shmem();

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
133133
void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136-
#if defined(DATA_A_IQ3_S)
136+
#if defined(DATA_A_IQ3_XXS)
137+
init_iq3xxs_shmem();
138+
#elif defined(DATA_A_IQ3_S)
137139
init_iq3s_shmem();
138140
#elif defined(DATA_A_IQ4_NL)
139141
init_iq4nl_shmem();

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9595
#endif
9696

9797
void main() {
98-
#if defined(DATA_A_IQ3_S)
98+
#if defined(DATA_A_IQ3_XXS)
99+
init_iq3xxs_shmem();
100+
#elif defined(DATA_A_IQ3_S)
99101
init_iq3s_shmem();
100102
#elif defined(DATA_A_IQ4_NL)
101103
init_iq4nl_shmem();
@@ -441,6 +443,31 @@ void main() {
441443

442444
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
443445
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
446+
#elif defined(DATA_A_IQ3_XXS)
447+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
448+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
449+
450+
const uint ib = idx / 128; // 2 values per idx
451+
const uint iqs = (idx % 128) / 2; // 0..63
452+
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
453+
454+
const float d = float(data_a[ib].d);
455+
const uint qs = data_a[ib].qs[iqs];
456+
const uint signs = pack32(u8vec4(
457+
data_a[ib].qs[is+0],
458+
data_a[ib].qs[is+1],
459+
data_a[ib].qs[is+2],
460+
data_a[ib].qs[is+3]
461+
));
462+
const float db = d * 0.5 * (0.5 + (signs >> 28));
463+
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
464+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
465+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
466+
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
467+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
468+
469+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
470+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
444471
#elif defined(DATA_A_IQ3_S)
445472
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
446473
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
106106
#endif
107107

108108
void main() {
109-
#if defined(DATA_A_IQ3_S)
109+
#if defined(DATA_A_IQ3_XXS)
110+
init_iq3xxs_shmem();
111+
#elif defined(DATA_A_IQ3_S)
110112
init_iq3s_shmem();
111113
#elif defined(DATA_A_IQ4_NL)
112114
init_iq4nl_shmem();

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,77 @@ struct block_q6_K_packed16
294294

295295
// IQuants
296296

297+
#define QUANT_K_IQ3_XXS 256
298+
#define QUANT_R_IQ3_XXS 1
299+
300+
struct block_iq3_xxs
301+
{
302+
float16_t d;
303+
uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8];
304+
};
305+
306+
struct block_iq3_xxs_packed16
307+
{
308+
float16_t d;
309+
uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16];
310+
};
311+
312+
#if defined(DATA_A_IQ3_XXS)
313+
314+
const uint32_t iq3xxs_grid_const[256] = {
315+
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
316+
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
317+
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
318+
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
319+
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
320+
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
321+
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
322+
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
323+
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
324+
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
325+
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
326+
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
327+
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
328+
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
329+
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
330+
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
331+
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
332+
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
333+
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
334+
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
335+
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
336+
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
337+
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
338+
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
339+
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
340+
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
341+
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
342+
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
343+
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
344+
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
345+
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
346+
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
347+
};
348+
349+
shared uint32_t iq3xxs_grid[512];
350+
351+
void init_iq3xxs_shmem()
352+
{
353+
// copy the table into shared memory and sync
354+
if (gl_LocalInvocationIndex.x < 32) {
355+
for (uint i = gl_LocalInvocationIndex.x; i < 512; i += 32) {
356+
iq3xxs_grid[i] = iq3xxs_grid_const[i];
357+
}
358+
}
359+
barrier();
360+
}
361+
362+
#define QUANT_K QUANT_K_IQ3_XXS
363+
#define QUANT_R QUANT_R_IQ3_XXS
364+
#define A_TYPE block_iq3_xxs
365+
#define A_TYPE_PACKED16 block_iq3_xxs_packed16
366+
#endif
367+
297368
#define QUANT_K_IQ3_S 256
298369
#define QUANT_R_IQ3_S 1
299370

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ const std::vector<std::string> type_names = {
5555
"q4_k",
5656
"q5_k",
5757
"q6_k",
58+
"iq3_xxs",
5859
"iq3_s",
5960
"iq4_nl"
6061
};
@@ -312,6 +313,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
312313

313314
for (const auto& tname : type_names) {
314315
if (tname == "iq3_s" && coopmat2) continue;
316+
if (tname == "iq3_xxs" && coopmat2) continue;
315317
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
316318
// For unaligned, load one at a time for f32/f16, or two at a time for quants
317319
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
@@ -367,7 +369,7 @@ void process_shaders() {
367369
if (tname == "f32") {
368370
continue;
369371
}
370-
if (tname == "iq3_s") {
372+
if (tname == "iq3_s" || tname == "iq3_xxs") {
371373
continue;
372374
}
373375

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3919,6 +3919,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39193919
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39203920
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39213921
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
3922+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ3_XXS,GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39223923
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ3_S, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39233924
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39243925
}

0 commit comments

Comments
 (0)