Skip to content

Commit fa92caa

Browse files
vulkan: initial support for IQ1_S and IQ1_M quantizations
1 parent 98f6b0f commit fa92caa

14 files changed

+537
-28
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 74 additions & 20 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1212
#endif
1313

1414
void main() {
15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem(gl_WorkGroupSize);
1717
if (gl_LocalInvocationIndex.x != 0) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
217217
#endif
218218

219219
void main() {
220-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
220+
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
221221
init_iq_shmem(gl_WorkGroupSize);
222222
if (gl_LocalInvocationIndex.x != 0) {
223223
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
22
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
33
#endif
4+
#if defined(DATA_A_IQ1_M)
5+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
6+
#endif
47

58
#include "types.comp"
69

@@ -88,6 +91,83 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
8891
}
8992
#endif
9093

94+
#if defined(DATA_A_IQ1_S)
95+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
96+
const uint ib32 = iqs / 32;
97+
const uint ib8 = iqs / 8;
98+
const int i8 = int(iqs % 8);
99+
const uint qh = data_a[a_offset + ib].qh[ib32];
100+
const uint qs = data_a[a_offset + ib].qs[ib8];
101+
const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);
102+
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
103+
const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);
104+
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
105+
// Signed bitfield extract.
106+
const ivec2 gvec = ivec2(
107+
bitfieldExtract(grid, 2 * (i8), 2),
108+
bitfieldExtract(grid, 2 * (i8 + 1), 2)
109+
);
110+
return dl * (vec2(gvec) + delta);
111+
}
112+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
113+
const uint ib32 = iqs / 32;
114+
const uint ib8 = iqs / 8;
115+
const int i8 = int(iqs % 8);
116+
const uint qh = data_a[a_offset + ib].qh[ib32];
117+
const uint qs = data_a[a_offset + ib].qs[ib8];
118+
const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;
119+
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
120+
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
121+
// Signed bitfield extract.
122+
const ivec4 gvec = ivec4(
123+
bitfieldExtract(grid, 2 * (i8), 2),
124+
bitfieldExtract(grid, 2 * (i8 + 1), 2),
125+
bitfieldExtract(grid, 2 * (i8 + 2), 2),
126+
bitfieldExtract(grid, 2 * (i8 + 3), 2)
127+
);
128+
return dl * (vec4(gvec) + delta);
129+
}
130+
#endif
131+
132+
#if defined(DATA_A_IQ1_M)
133+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
134+
const uint ib8 = iqs / 8;
135+
const uint ib16 = iqs / 16;
136+
const int i8 = int(iqs % 8);
137+
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
138+
const uint qs = data_a[a_offset + ib].qs[ib8];
139+
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
140+
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
141+
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
142+
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
143+
// Signed bitfield extract.
144+
const ivec2 gvec = ivec2(
145+
bitfieldExtract(grid, 2 * (i8), 2),
146+
bitfieldExtract(grid, 2 * (i8 + 1), 2)
147+
);
148+
return dl * (vec2(gvec) + delta);
149+
}
150+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
151+
const uint ib8 = iqs / 8;
152+
const uint ib16 = iqs / 16;
153+
const int i8 = int(iqs % 8);
154+
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
155+
const uint qs = data_a[a_offset + ib].qs[ib8];
156+
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
157+
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
158+
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
159+
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
160+
// Signed bitfield extract.
161+
const ivec4 gvec = ivec4(
162+
bitfieldExtract(grid, 2 * (i8), 2),
163+
bitfieldExtract(grid, 2 * (i8 + 1), 2),
164+
bitfieldExtract(grid, 2 * (i8 + 2), 2),
165+
bitfieldExtract(grid, 2 * (i8 + 3), 2)
166+
);
167+
return dl * (vec4(gvec) + delta);
168+
}
169+
#endif
170+
91171
#if defined(DATA_A_IQ2_XXS)
92172
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
93173
const uint ib32 = iqs / 32;
@@ -357,7 +437,16 @@ vec2 get_dm(uint ib, uint a_offset) {
357437
}
358438
#endif
359439

360-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
440+
#if defined(DATA_A_IQ1_M)
441+
vec2 get_dm(uint ib, uint a_offset) {
442+
const uint16_t[4] scales = data_a[a_offset + ib].scales;
443+
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
444+
const float d = float(uint16BitsToHalf(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)));
445+
return vec2(d, 0);
446+
}
447+
#endif
448+
449+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
361450
vec2 get_dm(uint ib, uint a_offset) {
362451
return vec2(float(data_a[a_offset + ib].d), 0);
363452
}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,56 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
301301
return ret;
302302
}
303303

304+
#if defined(DATA_A_IQ1_S)
305+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
306+
block_iq1_s block;
307+
};
308+
309+
float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
310+
{
311+
const float16_t d = bl.block.d;
312+
const uint idx = coordInBlock[1];
313+
314+
const uint ib32 = idx / 32;
315+
const uint ib8 = idx / 8;
316+
317+
const uint qh = bl.block.qh[ib32];
318+
const uint qs = bl.block.qs[ib8];
319+
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
320+
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
321+
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
322+
323+
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
324+
return ret;
325+
}
326+
#endif
327+
328+
#if defined(DATA_A_IQ1_M)
329+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
330+
block_iq1_m block;
331+
};
332+
333+
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
334+
{
335+
const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
336+
const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
337+
const uint idx = coordInBlock[1];
338+
339+
const uint ib8 = idx / 8;
340+
const uint ib16 = idx / 16;
341+
const int i8 = int(idx % 8);
342+
const uint sc = bl.block.scales[ib8 / 8];
343+
const uint qs = bl.block.qs[ib8];
344+
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
345+
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
346+
const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
347+
const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
348+
349+
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
350+
return ret;
351+
}
352+
#endif
353+
304354
#if defined(DATA_A_IQ2_XXS)
305355
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
306356
block_iq2_xxs block;
@@ -512,6 +562,10 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
512562
#define dequantFuncA dequantFuncQ5_K
513563
#elif defined(DATA_A_Q6_K)
514564
#define dequantFuncA dequantFuncQ6_K
565+
#elif defined(DATA_A_IQ1_S)
566+
#define dequantFuncA dequantFuncIQ1_S
567+
#elif defined(DATA_A_IQ1_M)
568+
#define dequantFuncA dequantFuncIQ1_M
515569
#elif defined(DATA_A_IQ2_XXS)
516570
#define dequantFuncA dequantFuncIQ2_XXS
517571
#elif defined(DATA_A_IQ2_XS)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#version 450
2+
3+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
4+
5+
#include "dequant_head.comp"
6+
7+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
8+
9+
layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
10+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
11+
12+
void main() {
13+
// Each thread handles 1 subblock (32 values with 2 scales)
14+
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
15+
16+
init_iq_shmem(gl_WorkGroupSize);
17+
18+
if (ib >= p.nel / 256) {
19+
return;
20+
}
21+
22+
const uint ib32 = gl_LocalInvocationID.x % 8;
23+
const uint ib64 = ib32 / 2;
24+
const uint b_idx = 256 * ib + 32 * ib32;
25+
26+
const uint16_t[4] scales = data_a[ib].scales;
27+
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
28+
const float d = float(uint16BitsToHalf(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)));
29+
30+
const uint sc = data_a[ib].scales[ib64];
31+
[[unroll]] for (int l = 0; l < 4; ++l) {
32+
const uint ib16 = 2 * ib32 + l / 2;
33+
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
34+
const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
35+
const uint qs = data_a[ib].qs[4 * ib32 + l];
36+
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
37+
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
38+
[[unroll]] for (int j = 0; j < 8; ++j) {
39+
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
40+
}
41+
}
42+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#version 450
2+
3+
#include "dequant_head.comp"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
11+
// Each thread handles 1 subblock (32 values with 2 scales)
12+
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
13+
14+
init_iq_shmem(gl_WorkGroupSize);
15+
16+
if (ib >= p.nel / 256) {
17+
return;
18+
}
19+
20+
const uint ib32 = gl_LocalInvocationID.x % 8;
21+
const uint b_idx = 256 * ib + 32 * ib32;
22+
23+
uint qh = data_a[ib].qh[ib32];
24+
const float d = float(data_a[ib].d);
25+
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
26+
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
27+
[[unroll]] for (uint l = 0; l < 4; ++l) {
28+
const uint qs = data_a[ib].qs[4 * ib32 + l];
29+
const uint hi = bitfieldExtract(qh, 3 * int(l), 3);
30+
const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);
31+
[[unroll]] for (int j = 0; j < 8; ++j) {
32+
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
33+
}
34+
}
35+
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
104104
#endif
105105

106106
void main() {
107-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
107+
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
108108
init_iq_shmem(gl_WorkGroupSize);
109109
#endif
110110

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem(gl_WorkGroupSize);
1717
#endif
1818

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
133133
void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
136+
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
137137
init_iq_shmem(gl_WorkGroupSize);
138138
#endif
139139

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#ifdef FLOAT16
77
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
88
#endif
9+
#if defined(DATA_A_IQ1_M)
10+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
11+
#endif
912

1013
#ifdef COOPMAT
1114
#extension GL_KHR_cooperative_matrix : enable
@@ -95,7 +98,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9598
#endif
9699

97100
void main() {
98-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
101+
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
99102
init_iq_shmem(gl_WorkGroupSize);
100103
#endif
101104

@@ -437,6 +440,56 @@ void main() {
437440

438441
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
439442
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
443+
#elif defined(DATA_A_IQ1_S)
444+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
445+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
446+
447+
const uint ib = idx / 128; // 2 values per idx
448+
const uint ib32 = (idx % 128) / 16; // 0..7
449+
const uint ib8 = (idx % 128) / 4;
450+
const int i8 = 2 * int(idx % 4);
451+
452+
const float d = float(data_a[ib].d);
453+
const uint qh = data_a[ib].qh[ib32];
454+
const uint qs = data_a[ib].qs[ib8];
455+
const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
456+
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
457+
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
458+
459+
const ivec2 gvec = ivec2(
460+
bitfieldExtract(grid, 2 * (i8), 2),
461+
bitfieldExtract(grid, 2 * (i8 + 1), 2)
462+
);
463+
const vec2 v = dl * (vec2(gvec) + delta);
464+
465+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
466+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
467+
#elif defined(DATA_A_IQ1_M)
468+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
469+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
470+
471+
const uint ib = idx / 128; // 2 values per idx
472+
const uint ib8 = (idx % 128) / 4;
473+
const uint ib16 = ib8 / 2;
474+
const int i8 = 2 * int(idx % 4);
475+
476+
const uint16_t[4] scales = data_a[ib].scales;
477+
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
478+
const float d = float(uint16BitsToHalf(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)));
479+
const uint sc = scales[ib8 / 8];
480+
const uint qs = data_a[ib].qs[ib8];
481+
const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
482+
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
483+
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
484+
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
485+
const ivec2 gvec = ivec2(
486+
bitfieldExtract(grid, 2 * (i8), 2),
487+
bitfieldExtract(grid, 2 * (i8 + 1), 2)
488+
);
489+
const vec2 v = dl * (vec2(gvec) + delta);
490+
491+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
492+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
440493
#elif defined(DATA_A_IQ2_XXS)
441494
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
442495
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
106106
#endif
107107

108108
void main() {
109-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
109+
#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
110110
init_iq_shmem(gl_WorkGroupSize);
111111
#endif
112112

0 commit comments

Comments
 (0)