Skip to content

Commit cd9db2d

Browse files
remyoudomphengjeffbolznv
authored andcommitted
vulkan: implement initial support for IQ2 and IQ3 quantizations (ggml-org#11360)
* vulkan: initial support for IQ3_S * vulkan: initial support for IQ3_XXS * vulkan: initial support for IQ2_XXS * vulkan: initial support for IQ2_XS * vulkan: optimize Q3_K by removing branches * vulkan: implement dequantize variants for coopmat2 * vulkan: initial support for IQ2_S * vulkan: vertically realign code * port failing dequant callbacks from mul_mm * Fix array length mismatches * vulkan: avoid using workgroup size before it is referenced * tests: increase timeout for Vulkan llvmpipe backend --------- Co-authored-by: Jeff Bolz <[email protected]>
1 parent d184068 commit cd9db2d

19 files changed

+1616
-40
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,8 @@ jobs:
346346
id: cmake_test
347347
run: |
348348
cd build
349-
ctest -L main --verbose --timeout 900
349+
# This is using llvmpipe and runs slower than other backends
350+
ctest -L main --verbose --timeout 1800
350351
351352
ubuntu-22-cmake-hip:
352353
runs-on: ubuntu-22.04

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 141 additions & 16 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1212
#endif
1313

1414
void main() {
15-
#if defined(DATA_A_IQ4_NL)
16-
init_iq4nl_shmem();
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
16+
init_iq_shmem(gl_WorkGroupSize);
1717
if (gl_LocalInvocationIndex.x != 0) {
1818
return;
1919
}

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ void quantize(uint dst_idx, uint src_idx)
217217
#endif
218218

219219
void main() {
220-
#if defined(DATA_A_IQ4_NL)
221-
init_iq4nl_shmem();
220+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
221+
init_iq_shmem(gl_WorkGroupSize);
222222
if (gl_LocalInvocationIndex.x != 0) {
223223
return;
224224
}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 217 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,222 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
8888
}
8989
#endif
9090

91+
#if defined(DATA_A_IQ2_XXS)
92+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
93+
const uint ib32 = iqs / 32;
94+
const uint ib8 = (iqs / 8) % 4;
95+
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
96+
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
97+
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
98+
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
99+
const float db = 0.25 * (0.5 + (signs >> 28));
100+
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
101+
// Add parity bit
102+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
103+
const uint sign = sign8 >> (iqs % 8);
104+
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
105+
bool sign0 = (sign & 1) != 0;
106+
bool sign1 = (sign & 2) != 0;
107+
return db * vec2(
108+
grid.x * (sign0 ? -1.0 : 1.0),
109+
grid.y * (sign1 ? -1.0 : 1.0)
110+
);
111+
}
112+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
113+
const uint ib32 = iqs / 32;
114+
const uint ib8 = (iqs / 8) % 4;
115+
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
116+
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
117+
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
118+
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
119+
const float db = 0.25 * (0.5 + (signs >> 28));
120+
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
121+
// Add parity bit
122+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
123+
const uint sign = sign8 >> (iqs % 8);
124+
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
125+
bool sign0 = (sign & 1) != 0;
126+
bool sign1 = (sign & 2) != 0;
127+
bool sign2 = (sign & 4) != 0;
128+
bool sign3 = (sign & 8) != 0;
129+
return db * vec4(
130+
grid.x * (sign0 ? -1.0 : 1.0),
131+
grid.y * (sign1 ? -1.0 : 1.0),
132+
grid.z * (sign2 ? -1.0 : 1.0),
133+
grid.w * (sign3 ? -1.0 : 1.0)
134+
);
135+
}
136+
#endif
137+
138+
#if defined(DATA_A_IQ2_XS)
139+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
140+
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
141+
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
142+
const float db = 0.25 * (0.5 + scale);
143+
const uint sign7 = qs >> 9;
144+
// Add parity bit
145+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
146+
const uint sign = sign8 >> (iqs % 8);
147+
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
148+
bool sign0 = (sign & 1) != 0;
149+
bool sign1 = (sign & 2) != 0;
150+
return db * vec2(
151+
grid.x * (sign0 ? -1.0 : 1.0),
152+
grid.y * (sign1 ? -1.0 : 1.0)
153+
);
154+
}
155+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
156+
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
157+
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
158+
const float db = 0.25 * (0.5 + scale);
159+
const uint sign7 = qs >> 9;
160+
// Add parity bit
161+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
162+
const uint sign = sign8 >> (iqs % 8);
163+
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
164+
bool sign0 = (sign & 1) != 0;
165+
bool sign1 = (sign & 2) != 0;
166+
bool sign2 = (sign & 4) != 0;
167+
bool sign3 = (sign & 8) != 0;
168+
return db * vec4(
169+
grid.x * (sign0 ? -1.0 : 1.0),
170+
grid.y * (sign1 ? -1.0 : 1.0),
171+
grid.z * (sign2 ? -1.0 : 1.0),
172+
grid.w * (sign3 ? -1.0 : 1.0)
173+
);
174+
}
175+
#endif
176+
177+
#if defined(DATA_A_IQ2_S)
178+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
179+
const uint ib32 = iqs / 32;
180+
const uint ib8 = iqs / 8;
181+
182+
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
183+
const uint qs = data_a[a_offset + ib].qs[ib8];
184+
const uint qh = data_a[a_offset + ib].qh[ib32];
185+
const uint qhshift = 2 * (ib8 % 4);
186+
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
187+
188+
const float db = 0.25 * (0.5 + scale);
189+
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
190+
bool sign0 = (sign & 1) != 0;
191+
bool sign1 = (sign & 2) != 0;
192+
return db * vec2(
193+
grid[iqs % 4] * (sign0 ? -1.0 : 1.0),
194+
grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0)
195+
);
196+
}
197+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
198+
const uint ib32 = iqs / 32;
199+
const uint ib8 = iqs / 8;
200+
201+
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
202+
const uint qs = data_a[a_offset + ib].qs[ib8];
203+
const uint qh = data_a[a_offset + ib].qh[ib32];
204+
const uint qhshift = 2 * (ib8 % 4);
205+
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
206+
207+
const float db = 0.25 * (0.5 + scale);
208+
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
209+
bool sign0 = (sign & 1) != 0;
210+
bool sign1 = (sign & 2) != 0;
211+
bool sign2 = (sign & 4) != 0;
212+
bool sign3 = (sign & 8) != 0;
213+
return db * vec4(
214+
grid.x * (sign0 ? -1.0 : 1.0),
215+
grid.y * (sign1 ? -1.0 : 1.0),
216+
grid.z * (sign2 ? -1.0 : 1.0),
217+
grid.w * (sign3 ? -1.0 : 1.0)
218+
);
219+
}
220+
#endif
221+
222+
#if defined(DATA_A_IQ3_XXS)
223+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
224+
const uint ib4 = iqs / 4;
225+
const uint ib32 = iqs / 32;
226+
const uint is = QUANT_K / 4 + 4 * ib32;
227+
const uint qs = data_a[a_offset + ib].qs[ib4];
228+
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
229+
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
230+
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
231+
const float db = 0.5 * (0.5 + (signs >> 28));
232+
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
233+
// Add parity bit
234+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
235+
const uint sign = sign8 >> (iqs % 8);
236+
const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
237+
bool sign0 = (sign & 1) != 0;
238+
bool sign1 = (sign & 2) != 0;
239+
return db * vec2(
240+
grid.x * (sign0 ? -1.0 : 1.0),
241+
grid.y * (sign1 ? -1.0 : 1.0)
242+
);
243+
}
244+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
245+
const uint ib4 = iqs / 4;
246+
const uint ib32 = iqs / 32;
247+
const uint is = QUANT_K / 4 + 4 * ib32;
248+
const uint qs = data_a[a_offset + ib].qs[ib4];
249+
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
250+
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
251+
const float db = 0.5 * (0.5 + (signs >> 28));
252+
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
253+
// Add parity bit
254+
const uint sign8 = sign7 | (bitCount(sign7) << 7);
255+
const uint sign = sign8 >> (iqs % 8);
256+
const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
257+
bool sign0 = (sign & 1) != 0;
258+
bool sign1 = (sign & 2) != 0;
259+
bool sign2 = (sign & 4) != 0;
260+
bool sign3 = (sign & 8) != 0;
261+
return db * vec4(
262+
grid.x * (sign0 ? -1.0 : 1.0),
263+
grid.y * (sign1 ? -1.0 : 1.0),
264+
grid.z * (sign2 ? -1.0 : 1.0),
265+
grid.w * (sign3 ? -1.0 : 1.0)
266+
);
267+
}
268+
#endif
269+
270+
#if defined(DATA_A_IQ3_S)
271+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
272+
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
273+
const uint qh = data_a[a_offset + ib].qh[iqs / 32];
274+
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
275+
const uint scale = data_a[a_offset + ib].scales[iqs / 64];
276+
bool sign0 = (sign & 1) != 0;
277+
bool sign1 = (sign & 2) != 0;
278+
const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
279+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
280+
return db * vec2(
281+
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
282+
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
283+
);
284+
}
285+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
286+
const uint ib4 = iqs / 4;
287+
const uint ib32 = iqs / 32;
288+
const uint qs = data_a[a_offset + ib].qs[ib4];
289+
const uint qh = data_a[a_offset + ib].qh[ib32];
290+
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
291+
const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
292+
bool sign0 = (sign & 1) != 0;
293+
bool sign1 = (sign & 2) != 0;
294+
bool sign2 = (sign & 4) != 0;
295+
bool sign3 = (sign & 8) != 0;
296+
const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
297+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
298+
return db * vec4(
299+
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
300+
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
301+
int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
302+
int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
303+
);
304+
}
305+
#endif
306+
91307
#if defined(DATA_A_IQ4_NL)
92308
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
93309
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -105,7 +321,7 @@ vec2 get_dm(uint ib, uint a_offset) {
105321
}
106322
#endif
107323

108-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
324+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
109325
vec2 get_dm(uint ib, uint a_offset) {
110326
return vec2(float(data_a[a_offset + ib].d), 0);
111327
}

0 commit comments

Comments
 (0)