Skip to content

Commit 042b334

Browse files
vulkan: initial support for IQ2_XS
1 parent a0b2015 commit 042b334

13 files changed

+308
-9
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 25 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1212
#endif
1313

1414
void main() {
15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem();
1717
if (gl_LocalInvocationIndex.x != 0) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
217217
#endif
218218

219219
void main() {
220-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
220+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
221221
init_iq_shmem();
222222
if (gl_LocalInvocationIndex.x != 0) {
223223
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,45 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
135135
}
136136
#endif
137137

138+
#if defined(DATA_A_IQ2_XS)
139+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
140+
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
141+
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
142+
const float db = 0.25 * (0.5 + scale);
143+
const uint sign7 = qs >> 9;
144+
// Add parity bit
145+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
146+
const uint sign = sign8 >> (iqs % 8);
147+
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
148+
bool sign0 = (sign & 1) != 0;
149+
bool sign1 = (sign & 2) != 0;
150+
return db * vec2(
151+
grid.x * (sign0 ? -1.0 : 1.0),
152+
grid.y * (sign1 ? -1.0 : 1.0)
153+
);
154+
}
155+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
156+
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
157+
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
158+
const float db = 0.25 * (0.5 + scale);
159+
const uint sign7 = qs >> 9;
160+
// Add parity bit
161+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
162+
const uint sign = sign8 >> (iqs % 8);
163+
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
164+
bool sign0 = (sign & 1) != 0;
165+
bool sign1 = (sign & 2) != 0;
166+
bool sign2 = (sign & 4) != 0;
167+
bool sign3 = (sign & 8) != 0;
168+
return db * vec4(
169+
grid.x * (sign0 ? -1.0 : 1.0),
170+
grid.y * (sign1 ? -1.0 : 1.0),
171+
grid.z * (sign2 ? -1.0 : 1.0),
172+
grid.w * (sign3 ? -1.0 : 1.0)
173+
);
174+
}
175+
#endif
176+
138177
#if defined(DATA_A_IQ3_XXS)
139178
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
140179
const uint ib4 = iqs / 4;
@@ -237,7 +276,7 @@ vec2 get_dm(uint ib, uint a_offset) {
237276
}
238277
#endif
239278

240-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
279+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
241280
vec2 get_dm(uint ib, uint a_offset) {
242281
return vec2(float(data_a[a_offset + ib].d), 0);
243282
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#version 450
2+
3+
#include "dequant_head.comp"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
11+
// Each thread handles 1 subblock (32 values with 2 scales)
12+
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
13+
14+
init_iq_shmem();
15+
16+
if (ib >= p.nel / 256) {
17+
return;
18+
}
19+
20+
const uint ib32 = gl_LocalInvocationID.x % 8;
21+
const uint b_idx = 256 * ib + 32 * ib32;
22+
23+
const float d = float(data_a[ib].d);
24+
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
25+
const vec2 db = d * (0.5 + scale) * 0.25;
26+
27+
[[unroll]] for (uint l = 0; l < 4; ++l) {
28+
uint16_t qs = data_a[ib].qs[4 * ib32 + l];
29+
const uint sign7 = qs >> 9;
30+
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
31+
const uvec2 grid = iq2xs_grid[qs & 511];
32+
const u8vec4 grid0 = unpack8(grid.x);
33+
const u8vec4 grid1 = unpack8(grid.y);
34+
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
35+
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
36+
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
37+
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
38+
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
39+
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
40+
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
41+
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
42+
}
43+
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
104104
#endif
105105

106106
void main() {
107-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
107+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
108108
init_iq_shmem();
109109
#endif
110110

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem();
1717
#endif
1818

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
133133
void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
136+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
137137
init_iq_shmem();
138138
#endif
139139

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9595
#endif
9696

9797
void main() {
98-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
98+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
9999
init_iq_shmem();
100100
#endif
101101

@@ -462,6 +462,26 @@ void main() {
462462
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
463463
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
464464

465+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
466+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
467+
#elif defined(DATA_A_IQ2_XS)
468+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
469+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
470+
471+
const uint ib = idx / 128; // 2 values per idx
472+
const uint ib32 = (idx % 128) / 16; // 0..7
473+
const uint ib8 = (idx / 4) % 4; // 0..3
474+
475+
const float d = float(data_a[ib].d);
476+
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
477+
const float db = d * 0.25 * (0.5 + scale);
478+
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
479+
const uint sign7 = qs >> 9;
480+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
481+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
482+
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
483+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
484+
465485
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
466486
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
467487
#elif defined(DATA_A_IQ3_XXS)

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
106106
#endif
107107

108108
void main() {
109-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
109+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
110110
init_iq_shmem();
111111
#endif
112112

0 commit comments

Comments
 (0)