Skip to content

Commit 15d6b15

Browse files
vulkan: initial support for IQ2_S
1 parent 02214b6 commit 15d6b15

14 files changed

+491
-34
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 50 additions & 25 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1212
#endif
1313

1414
void main() {
15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem();
1717
if (gl_LocalInvocationIndex.x != 0) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
217217
#endif
218218

219219
void main() {
220-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
220+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
221221
init_iq_shmem();
222222
if (gl_LocalInvocationIndex.x != 0) {
223223
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,51 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
174174
}
175175
#endif
176176

177+
#if defined(DATA_A_IQ2_S)
178+
// Dequantize two consecutive IQ2_S values (block scale `d` NOT applied;
// the result is only multiplied by the 4-bit sub-scale).
//   ib       - block index, relative to a_offset
//   iqs      - index of the first of the two values inside the block
//   a_offset - base offset into data_a
// Returns the two values as a vec2.
// NOTE(review): reads grid components (iqs % 4) and (iqs % 4) + 1, so iqs is
// presumably always even - confirm against callers.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint blk   = a_offset + ib;
    const uint sub32 = iqs / 32;   // group of 32 values
    const uint sub8  = iqs / 8;    // group of 8 values
    const uint in8   = iqs % 8;    // position inside the group of 8
    const uint in4   = iqs % 4;    // position inside the packed grid word

    // One scales byte per group of 32: one nibble per half (16 values).
    const uint nibble_shift = 4 * ((iqs / 16) & 1);
    const uint scale = (data_a[blk].scales[sub32] >> nibble_shift) & 0xf;

    // Grid index: 8 low bits from qs, 2 high bits taken from qh.
    const uint qs = data_a[blk].qs[sub8];
    const uint qh = data_a[blk].qh[sub32];
    const uint qhshift = 2 * (sub8 % 4);
    const uint grid_idx = qs | ((qh << (8 - qhshift)) & 0x300);

    // Sign bytes are read from qs starting at QUANT_K / 8; bit 0 after the
    // shift belongs to the first returned value.
    const uint sign = data_a[blk].qs[QUANT_K / 8 + sub8] >> in8;

    const float db = 0.25 * (0.5 + scale);
    const u8vec4 grid = unpack8(iq2s_grid[grid_idx][in8 / 4]);

    const bool neg0 = (sign & 1) != 0;
    const bool neg1 = (sign & 2) != 0;

    return db * vec2(
        grid[in4] * (neg0 ? -1.0 : 1.0),
        grid[in4 + 1] * (neg1 ? -1.0 : 1.0)
    );
}
197+
// Dequantize four consecutive IQ2_S values (block scale `d` NOT applied;
// the result is only multiplied by the 4-bit sub-scale).
//   ib       - block index, relative to a_offset
//   iqs      - index of the first of the four values inside the block
//   a_offset - base offset into data_a
// Returns the four values as a vec4.
// NOTE(review): all four components of one packed grid word are returned, so
// iqs is presumably a multiple of 4 - confirm against callers.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    const uint blk   = a_offset + ib;
    const uint sub32 = iqs / 32;   // group of 32 values
    const uint sub8  = iqs / 8;    // group of 8 values
    const uint in8   = iqs % 8;    // position inside the group of 8

    // One scales byte per group of 32: one nibble per half (16 values).
    const uint nibble_shift = 4 * ((iqs / 16) & 1);
    const uint scale = (data_a[blk].scales[sub32] >> nibble_shift) & 0xf;

    // Grid index: 8 low bits from qs, 2 high bits taken from qh.
    const uint qs = data_a[blk].qs[sub8];
    const uint qh = data_a[blk].qh[sub32];
    const uint qhshift = 2 * (sub8 % 4);
    const uint grid_idx = qs | ((qh << (8 - qhshift)) & 0x300);

    // Sign bytes are read from qs starting at QUANT_K / 8; after the shift,
    // bits 0..3 belong to the four returned values.
    const uint sign = data_a[blk].qs[QUANT_K / 8 + sub8] >> in8;

    const float db = 0.25 * (0.5 + scale);
    const u8vec4 grid = unpack8(iq2s_grid[grid_idx][in8 / 4]);

    const bool neg0 = (sign & 1) != 0;
    const bool neg1 = (sign & 2) != 0;
    const bool neg2 = (sign & 4) != 0;
    const bool neg3 = (sign & 8) != 0;

    return db * vec4(
        grid.x * (neg0 ? -1.0 : 1.0),
        grid.y * (neg1 ? -1.0 : 1.0),
        grid.z * (neg2 ? -1.0 : 1.0),
        grid.w * (neg3 ? -1.0 : 1.0)
    );
}
220+
#endif
221+
177222
#if defined(DATA_A_IQ3_XXS)
178223
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
179224
const uint ib4 = iqs / 4;
@@ -276,7 +321,7 @@ vec2 get_dm(uint ib, uint a_offset) {
276321
}
277322
#endif
278323

279-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
324+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
280325
vec2 get_dm(uint ib, uint a_offset) {
281326
return vec2(float(data_a[a_offset + ib].d), 0);
282327
}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,33 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor
361361
}
362362
#endif
363363

364+
#if defined(DATA_A_IQ2_S)
365+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
366+
block_iq2_s block;
367+
};
368+
369+
// Decode callback for IQ2_S: returns one dequantized element of a block.
//   bl           - buffer reference to the block being decoded
//   blockCoords  - not used by this function
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns d * 0.25 * (0.5 + scale) * grid_value, negated when the element's
// sign bit is set.
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const float16_t d = bl.block.d;
    const uint idx = coordInBlock[1];

    const uint ib32 = (idx & 0xE0) >> 5;         // 0..7
    const uint ib8 = (idx & 0xF8) >> 3;          // 0..31
    const uint qhshift = 2 * (ib8 % 4);

    // Each scales byte packs two 4-bit scales, one per 16 values. Bit 4 of idx
    // selects the nibble, so the shift must be 0 or 4: (idx & 0x10) >> 2.
    // (Shifting by (idx & 0x10) >> 4 would only move by a single bit.)
    const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
    const uint qs = bl.block.qs[ib8];
    const uint qh = bl.block.qh[ib32];
    // Sign bytes are read from qs starting at QUANT_K / 8, one byte per 8 values.
    const uint sign = bl.block.qs[QUANT_K / 8 + ib8];

    const float16_t dscale = d * 0.25hf * (0.5hf + float16_t(scale));
    // 10-bit grid index: 8 bits from qs plus 2 bits from qh. Bit 2 of idx picks
    // the 32-bit half of the grid entry, bits 0..1 pick the byte inside it.
    const uint8_t g = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2])[idx & 3];

    const bool negative = (sign & (1 << (idx & 7))) != 0;
    float16_t ret = dscale * float16_t(g) * (negative ? -1.0hf : 1.0hf);
    return ret;
}
389+
#endif
390+
364391
#if defined(DATA_A_IQ3_XXS)
365392
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
366393
block_iq3_xxs block;
@@ -468,6 +495,8 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
468495
#define dequantFuncA dequantFuncIQ2_XXS
469496
#elif defined(DATA_A_IQ2_XS)
470497
#define dequantFuncA dequantFuncIQ2_XS
498+
#elif defined(DATA_A_IQ2_S)
499+
#define dequantFuncA dequantFuncIQ2_S
471500
#elif defined(DATA_A_IQ3_XXS)
472501
#define dequantFuncA dequantFuncIQ3_XXS
473502
#elif defined(DATA_A_IQ3_S)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#version 450
2+
3+
#include "dequant_head.comp"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_iq2_s data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
// Dequantizes IQ2_S blocks from data_a into data_b.
// Each invocation expands one group of 32 values: 4 grid entries of 8 values,
// using the two 4-bit scales stored in one scales byte.
void main() {
    // Each thread handles 1 subblock (32 values with 2 scales)
    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;

    // NOTE(review): presumably sets up the shared grid table and has to be
    // reached by every invocation; kept ahead of the early return - confirm.
    init_iq_shmem();

    if (ib >= p.nel / 256) {
        return;
    }

    const uint ib32 = gl_LocalInvocationID.x % 8;
    const uint b_idx = 256 * ib + 32 * ib32;

    const float d = float(data_a[ib].d);
    // Low nibble scales the first 16 values, high nibble the last 16.
    const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
    const vec2 db = d * (0.5 + scale) * 0.25;

    uint qh = data_a[ib].qh[ib32];
    [[unroll]] for (uint l = 0; l < 4; ++l) {
        uint qs = data_a[ib].qs[4 * ib32 + l];
        // Sign bytes are read from qs starting at QUANT_K / 8.
        const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
        // Build the 10-bit grid index: 8 bits from qs plus 2 bits from qh.
        qs |= (qh << (8 - 2 * l)) & 0x300;
        // Use the full 10-bit index; masking with 511 would drop the top qh bit.
        const uvec2 grid = iq2s_grid[qs];
        const u8vec4 grid0 = unpack8(grid.x);
        const u8vec4 grid1 = unpack8(grid.y);
        data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));
    }
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
104104
#endif
105105

106106
void main() {
107-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
107+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
108108
init_iq_shmem();
109109
#endif
110110

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem();
1717
#endif
1818

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
133133
void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
136+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
137137
init_iq_shmem();
138138
#endif
139139

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9595
#endif
9696

9797
void main() {
98-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
98+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
9999
init_iq_shmem();
100100
#endif
101101

@@ -480,6 +480,28 @@ void main() {
480480
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
481481
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
482482

483+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
484+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
485+
#elif defined(DATA_A_IQ2_S)
486+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
487+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
488+
489+
const uint ib = idx / 128; // 2 values per idx
490+
const uint ib8 = (idx % 128) / 4; // 0..31
491+
const uint ib32 = ib8 / 4; // 0..7
492+
493+
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
494+
const uint qs = data_a[ib].qs[ib8];
495+
const uint qh = data_a[ib].qh[ib32];
496+
const uint qhshift = 2 * (ib8 % 4);
497+
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
498+
499+
const float d = float(data_a[ib].d);
500+
const float db = d * 0.25 * (0.5 + scale);
501+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
502+
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
503+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
504+
483505
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
484506
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
485507
#elif defined(DATA_A_IQ3_XXS)

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
106106
#endif
107107

108108
void main() {
109-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
109+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
110110
init_iq_shmem();
111111
#endif
112112

0 commit comments

Comments
 (0)