Skip to content

Commit 3d2c7a0

Browse files
vulkan: initial support for IQ3_S
1 parent 05f63cc commit 3d2c7a0

File tree

11 files changed

+252
-6
lines changed

11 files changed

+252
-6
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 25 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,43 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
8888
}
8989
#endif
9090

91+
#if defined(DATA_A_IQ3_S)
92+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
93+
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
94+
const uint qh = data_a[a_offset + ib].qh[iqs / 32];
95+
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
96+
const uint scale = data_a[a_offset + ib].scales[iqs / 64];
97+
bool sign0 = (sign & 1) != 0;
98+
bool sign1 = (sign & 2) != 0;
99+
const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
100+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
101+
return db * vec2(
102+
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
103+
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
104+
);
105+
}
106+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
107+
const uint ib4 = iqs / 4;
108+
const uint ib32 = iqs / 32;
109+
const uint qs = data_a[a_offset + ib].qs[ib4];
110+
const uint qh = data_a[a_offset + ib].qh[ib32];
111+
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
112+
const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
113+
bool sign0 = (sign & 1) != 0;
114+
bool sign1 = (sign & 2) != 0;
115+
bool sign2 = (sign & 4) != 0;
116+
bool sign3 = (sign & 8) != 0;
117+
const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
118+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
119+
return db * vec4(
120+
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
121+
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
122+
int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
123+
int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
124+
);
125+
}
126+
#endif
127+
91128
#if defined(DATA_A_IQ4_NL)
92129
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
93130
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -105,7 +142,7 @@ vec2 get_dm(uint ib, uint a_offset) {
105142
}
106143
#endif
107144

108-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
145+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
109146
vec2 get_dm(uint ib, uint a_offset) {
110147
return vec2(float(data_a[a_offset + ib].d), 0);
111148
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#version 450
2+
3+
#include "dequant_head.comp"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_iq3_s data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
11+
// Each thread handles 1 scale nibble.
12+
// Each block contains 4 scale bytes (8 scales) for 256 output values.
13+
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
14+
15+
init_iq3s_shmem();
16+
17+
if (ib >= p.nel / 256) {
18+
return;
19+
}
20+
21+
const uint is = gl_LocalInvocationID.x % 8;
22+
const uint b_idx = 256 * ib + 32 * is;
23+
24+
const float d = float(data_a[ib].d);
25+
const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf));
26+
27+
// We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
28+
uint qh = data_a[ib].qh[is];
29+
[[unroll]] for (uint l = 0; l < 8; ++l) {
30+
uint qs = data_a[ib].qs[8 * is + l];
31+
uint gidx = qs | ((qh << (8 - l)) & 256);
32+
uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1));
33+
u8vec4 grid = unpack8(iq3s_grid[gidx]);
34+
data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
35+
data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
36+
data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
37+
data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
38+
}
39+
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
104104
#endif
105105

106106
void main() {
107-
#if defined(DATA_A_IQ4_NL)
107+
#if defined(DATA_A_IQ3_S)
108+
init_iq3s_shmem();
109+
#elif defined(DATA_A_IQ4_NL)
108110
init_iq4nl_shmem();
109111
#endif
110112

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15-
#if defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ3_S)
16+
init_iq3s_shmem();
17+
#elif defined(DATA_A_IQ4_NL)
1618
init_iq4nl_shmem();
1719
#endif
1820

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
133133
void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136-
#if defined(DATA_A_IQ4_NL)
136+
#if defined(DATA_A_IQ3_S)
137+
init_iq3s_shmem();
138+
#elif defined(DATA_A_IQ4_NL)
137139
init_iq4nl_shmem();
138140
#endif
139141

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9595
#endif
9696

9797
void main() {
98-
#if defined(DATA_A_IQ4_NL)
98+
#if defined(DATA_A_IQ3_S)
99+
init_iq3s_shmem();
100+
#elif defined(DATA_A_IQ4_NL)
99101
init_iq4nl_shmem();
100102
#endif
101103

@@ -439,6 +441,26 @@ void main() {
439441

440442
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
441443
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
444+
#elif defined(DATA_A_IQ3_S)
445+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
446+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
447+
448+
const uint ib = idx / 128; // 2 values per idx
449+
const uint iqs = (idx % 128) / 2; // 0..63
450+
const uint iqh = iqs / 8;
451+
452+
const float d = float(data_a[ib].d);
453+
const uint qs = data_a[ib].qs[iqs];
454+
const uint qh = data_a[ib].qh[iqh];
455+
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
456+
const uint scale = data_a[ib].scales[iqs / 16];
457+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
458+
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
459+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
460+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
461+
462+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
463+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
442464
#elif defined(DATA_A_IQ4_NL)
443465
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
444466
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
106106
#endif
107107

108108
void main() {
109-
#if defined(DATA_A_IQ4_NL)
109+
#if defined(DATA_A_IQ3_S)
110+
init_iq3s_shmem();
111+
#elif defined(DATA_A_IQ4_NL)
110112
init_iq4nl_shmem();
111113
#endif
112114

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,115 @@ struct block_q6_K_packed16
294294

295295
// IQuants
296296

297+
#define QUANT_K_IQ3_S 256
298+
#define QUANT_R_IQ3_S 1
299+
300+
struct block_iq3_s
301+
{
302+
float16_t d;
303+
uint8_t qs[QUANT_K_IQ3_S/4];
304+
uint8_t qh[QUANT_K_IQ3_S/32];
305+
uint8_t signs[QUANT_K_IQ3_S/8];
306+
uint8_t scales[QUANT_K_IQ3_S/64];
307+
};
308+
309+
struct block_iq3_s_packed16
310+
{
311+
float16_t d;
312+
uint16_t qs[QUANT_K_IQ3_S/4/2];
313+
uint16_t qh[QUANT_K_IQ3_S/32/2];
314+
uint16_t signs[QUANT_K_IQ3_S/8/2];
315+
uint16_t scales[QUANT_K_IQ3_S/64/2];
316+
};
317+
318+
#if defined(DATA_A_IQ3_S)
319+
320+
const uint32_t iq3s_grid_const[512] = {
321+
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
322+
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
323+
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
324+
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
325+
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
326+
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
327+
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
328+
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
329+
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
330+
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
331+
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
332+
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
333+
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
334+
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
335+
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
336+
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
337+
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
338+
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
339+
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
340+
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
341+
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
342+
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
343+
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
344+
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
345+
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
346+
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
347+
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
348+
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
349+
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
350+
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
351+
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
352+
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
353+
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
354+
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
355+
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
356+
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
357+
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
358+
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
359+
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
360+
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
361+
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
362+
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
363+
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
364+
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
365+
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
366+
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
367+
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
368+
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
369+
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
370+
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
371+
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
372+
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
373+
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
374+
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
375+
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
376+
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
377+
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
378+
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
379+
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
380+
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
381+
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
382+
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
383+
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
384+
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
385+
};
386+
387+
shared uint32_t iq3s_grid[512];
388+
389+
void init_iq3s_shmem()
390+
{
391+
// copy the table into shared memory and sync
392+
if (gl_LocalInvocationIndex.x < 32) {
393+
for (uint i = gl_LocalInvocationIndex.x; i < 512; i += 32) {
394+
iq3s_grid[i] = iq3s_grid_const[i];
395+
}
396+
}
397+
barrier();
398+
}
399+
400+
#define QUANT_K QUANT_K_IQ3_S
401+
#define QUANT_R QUANT_R_IQ3_S
402+
#define A_TYPE block_iq3_s
403+
#define A_TYPE_PACKED16 block_iq3_s_packed16
404+
#endif
405+
297406
#define QUANT_K_IQ4_NL 32
298407
#define QUANT_R_IQ4_NL 2
299408

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ const std::vector<std::string> type_names = {
5555
"q4_k",
5656
"q5_k",
5757
"q6_k",
58+
"iq3_s",
5859
"iq4_nl"
5960
};
6061

@@ -310,6 +311,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
310311
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
311312

312313
for (const auto& tname : type_names) {
314+
if (tname == "iq3_s" && coopmat2) continue;
313315
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
314316
// For unaligned, load one at a time for f32/f16, or two at a time for quants
315317
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
@@ -365,6 +367,9 @@ void process_shaders() {
365367
if (tname == "f32") {
366368
continue;
367369
}
370+
if (tname == "iq3_s") {
371+
continue;
372+
}
368373

369374
if (tname == "f16") {
370375
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3919,6 +3919,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39193919
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39203920
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39213921
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
3922+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ3_S, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39223923
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
39233924
}
39243925

0 commit comments

Comments
 (0)