Skip to content

vulkan: implement initial support for IQ2 and IQ3 quantizations #11360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ jobs:
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 1800

ubuntu-22-cmake-hip:
runs-on: ubuntu-22.04
Expand Down
157 changes: 141 additions & 16 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif

void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ void quantize(uint dst_idx, uint src_idx)
#endif

void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}
Expand Down
218 changes: 217 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,222 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif

#if defined(DATA_A_IQ2_XXS)
// IQ2_XXS: returns the two values at element positions iqs and iqs + 1 of
// block (a_offset + ib), with signs applied and multiplied by the sub-block
// scale factor. The block's `d` field is not read here.
// NOTE(review): only the two lowest bits of the shifted sign word are used,
// which presumably expects an even iqs - confirm against callers.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint group32 = iqs / 32;      // index of the 32-element group
    const uint group8 = (iqs / 8) % 4;  // index of the 8-element group inside it
    const uint lane = iqs % 8;          // position inside the 8-element group

    // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
    const uint packed = pack32(u16vec2(data_a_packed16[blk].qs[4 * group32 + 2],
                                       data_a_packed16[blk].qs[4 * group32 + 3]));
    const float db = 0.25 * (0.5 + (packed >> 28));

    // Seven sign bits are stored; the bit count is OR-ed in from bit 7 upwards
    // to supply the parity bit.
    const uint stored = bitfieldExtract(packed, 7 * int(group8), 7);
    const uint with_parity = stored | (bitCount(stored) << 7);
    const uint sign = with_parity >> lane;

    // Look up the grid word for this group and move the wanted bytes to the bottom.
    const uint grid_idx = data_a[blk].qs[8 * group32 + group8];
    const u8vec4 grid = unpack8(iq2xxs_grid[grid_idx][lane / 4] >> (8 * (iqs % 4)));

    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    return db * vec2(
        grid.x * sx,
        grid.y * sy
    );
}
// IQ2_XXS: four-wide variant. Returns the four values starting at element
// position iqs of block (a_offset + ib), with signs applied and multiplied by
// the sub-block scale factor. The block's `d` field is not read here.
// NOTE(review): presumably expects iqs to be a multiple of 4 - confirm
// against callers.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint group32 = iqs / 32;      // index of the 32-element group
    const uint group8 = (iqs / 8) % 4;  // index of the 8-element group inside it
    const uint lane = iqs % 8;          // position inside the 8-element group

    // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
    const uint packed = pack32(u16vec2(data_a_packed16[blk].qs[4 * group32 + 2],
                                       data_a_packed16[blk].qs[4 * group32 + 3]));
    const float db = 0.25 * (0.5 + (packed >> 28));

    // Seven sign bits are stored; the bit count is OR-ed in from bit 7 upwards
    // to supply the parity bit.
    const uint stored = bitfieldExtract(packed, 7 * int(group8), 7);
    const uint with_parity = stored | (bitCount(stored) << 7);
    const uint sign = with_parity >> lane;

    const uint grid_idx = data_a[blk].qs[8 * group32 + group8];
    const u8vec4 grid = unpack8(iq2xxs_grid[grid_idx][lane / 4] >> (8 * (iqs % 4)));

    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    const float sz = ((sign & 4) != 0) ? -1.0 : 1.0;
    const float sw = ((sign & 8) != 0) ? -1.0 : 1.0;
    return db * vec4(
        grid.x * sx,
        grid.y * sy,
        grid.z * sz,
        grid.w * sw
    );
}
#endif

#if defined(DATA_A_IQ2_XS)
// IQ2_XS: returns the two values at element positions iqs and iqs + 1 of
// block (a_offset + ib), with signs applied and multiplied by the 4-bit
// sub-block scale factor. The block's `d` field is not read here.
// NOTE(review): only the two lowest bits of the shifted sign word are used,
// which presumably expects an even iqs - confirm against callers.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint lane = iqs % 8;  // position inside the 8-element group

    // One scale entry covers 32 elements; each 16-element half uses one nibble.
    const uint nibble_shift = 4 * ((iqs / 16) & 1);
    const uint scale = (data_a[blk].scales[iqs / 32] >> nibble_shift) & 0xf;
    const float db = 0.25 * (0.5 + scale);

    // Each qs entry holds a 9-bit grid index (low bits) and 7 sign bits (above).
    const uint entry = data_a[blk].qs[iqs / 8];
    const uint stored = entry >> 9;
    // Add parity bit
    const uint with_parity = stored | (bitCount(stored) << 7);
    const uint sign = with_parity >> lane;

    const u8vec4 grid = unpack8(iq2xs_grid[entry & 511][lane / 4] >> (8 * (iqs % 4)));

    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    return db * vec2(
        grid.x * sx,
        grid.y * sy
    );
}
// IQ2_XS: four-wide variant. Returns the four values starting at element
// position iqs of block (a_offset + ib), with signs applied and multiplied by
// the 4-bit sub-block scale factor. The block's `d` field is not read here.
// NOTE(review): presumably expects iqs to be a multiple of 4 - confirm
// against callers.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint lane = iqs % 8;  // position inside the 8-element group

    // One scale entry covers 32 elements; each 16-element half uses one nibble.
    const uint nibble_shift = 4 * ((iqs / 16) & 1);
    const uint scale = (data_a[blk].scales[iqs / 32] >> nibble_shift) & 0xf;
    const float db = 0.25 * (0.5 + scale);

    // Each qs entry holds a 9-bit grid index (low bits) and 7 sign bits (above).
    const uint entry = data_a[blk].qs[iqs / 8];
    const uint stored = entry >> 9;
    // Add parity bit
    const uint with_parity = stored | (bitCount(stored) << 7);
    const uint sign = with_parity >> lane;

    const u8vec4 grid = unpack8(iq2xs_grid[entry & 511][lane / 4] >> (8 * (iqs % 4)));

    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    const float sz = ((sign & 4) != 0) ? -1.0 : 1.0;
    const float sw = ((sign & 8) != 0) ? -1.0 : 1.0;
    return db * vec4(
        grid.x * sx,
        grid.y * sy,
        grid.z * sz,
        grid.w * sw
    );
}
#endif

#if defined(DATA_A_IQ2_S)
// IQ2_S: returns the two values at element positions iqs and iqs + 1 of
// block (a_offset + ib), with signs applied and multiplied by the 4-bit
// sub-block scale factor. The block's `d` field is not read here.
// NOTE(review): the second component indexes grid[(iqs % 4) + 1], which
// presumably expects an even iqs - confirm against callers.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint group32 = iqs / 32;  // index of the 32-element group
    const uint group8 = iqs / 8;    // index of the 8-element group
    const uint lane = iqs % 8;      // position inside the 8-element group
    const uint first = iqs % 4;     // component of the grid word to start from

    // One scale entry covers 32 elements; each 16-element half uses one nibble.
    const uint nibble_shift = 4 * ((iqs / 16) & 1);
    const uint scale = (data_a[blk].scales[group32] >> nibble_shift) & 0xf;
    const float db = 0.25 * (0.5 + scale);

    // The grid index is 10 bits: 8 from qs, 2 more taken from qh.
    const uint low = data_a[blk].qs[group8];
    const uint high_src = data_a[blk].qh[group32];
    const uint high_shift = 8 - 2 * (group8 % 4);
    const uint grid_idx = low | ((high_src << high_shift) & 0x300);
    const u8vec4 grid = unpack8(iq2s_grid[grid_idx][lane / 4]);

    // Sign bytes are read from qs at offset QUANT_K / 8.
    const uint sign = data_a[blk].qs[QUANT_K / 8 + group8] >> lane;

    const float s0 = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float s1 = ((sign & 2) != 0) ? -1.0 : 1.0;
    return db * vec2(
        grid[first] * s0,
        grid[first + 1] * s1
    );
}
// IQ2_S: four-wide variant. Returns the four values of the grid word selected
// by iqs in block (a_offset + ib), with signs applied and multiplied by the
// 4-bit sub-block scale factor. The block's `d` field is not read here.
// NOTE(review): the grid word is used whole (no per-byte shift), which
// presumably expects iqs to be a multiple of 4 - confirm against callers.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint group32 = iqs / 32;  // index of the 32-element group
    const uint group8 = iqs / 8;    // index of the 8-element group
    const uint lane = iqs % 8;      // position inside the 8-element group

    // One scale entry covers 32 elements; each 16-element half uses one nibble.
    const uint nibble_shift = 4 * ((iqs / 16) & 1);
    const uint scale = (data_a[blk].scales[group32] >> nibble_shift) & 0xf;
    const float db = 0.25 * (0.5 + scale);

    // The grid index is 10 bits: 8 from qs, 2 more taken from qh.
    const uint low = data_a[blk].qs[group8];
    const uint high_src = data_a[blk].qh[group32];
    const uint high_shift = 8 - 2 * (group8 % 4);
    const uint grid_idx = low | ((high_src << high_shift) & 0x300);
    const u8vec4 grid = unpack8(iq2s_grid[grid_idx][lane / 4]);

    // Sign bytes are read from qs at offset QUANT_K / 8.
    const uint sign = data_a[blk].qs[QUANT_K / 8 + group8] >> lane;

    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    const float sz = ((sign & 4) != 0) ? -1.0 : 1.0;
    const float sw = ((sign & 8) != 0) ? -1.0 : 1.0;
    return db * vec4(
        grid.x * sx,
        grid.y * sy,
        grid.z * sz,
        grid.w * sw
    );
}
#endif

#if defined(DATA_A_IQ3_XXS)
// IQ3_XXS: returns the two values at element positions iqs and iqs + 1 of
// block (a_offset + ib), with signs applied and multiplied by the sub-block
// scale factor. The block's `d` field is not read here.
// NOTE(review): only the two lowest bits of the shifted sign word are used,
// which presumably expects an even iqs - confirm against callers.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint quad = iqs / 4;      // index of the 4-element grid entry
    const uint group32 = iqs / 32;  // index of the 32-element group
    const uint lane = iqs % 8;      // position inside the 8-element group

    // Byte offset into qs of this group's packed sign/scale word.
    const uint sign_byte = QUANT_K / 4 + 4 * group32;
    // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
    const uint packed = pack32(u16vec2(data_a_packed16[blk].qs[sign_byte / 2],
                                       data_a_packed16[blk].qs[sign_byte / 2 + 1]));
    const float db = 0.5 * (0.5 + (packed >> 28));

    const uint stored = bitfieldExtract(packed, 7 * (int(quad / 2) % 4), 7);
    // Add parity bit
    const uint with_parity = stored | (bitCount(stored) << 7);
    const uint sign = with_parity >> lane;

    const uint grid_idx = data_a[blk].qs[quad];
    const u8vec4 grid = unpack8(iq3xxs_grid[grid_idx] >> (8 * (iqs % 4)));

    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    return db * vec2(
        grid.x * sx,
        grid.y * sy
    );
}
// IQ3_XXS: four-wide variant. Returns the four values of the grid entry
// selected by iqs in block (a_offset + ib), with signs applied and multiplied
// by the sub-block scale factor. The block's `d` field is not read here.
// NOTE(review): the grid word is used whole (no per-byte shift), which
// presumably expects iqs to be a multiple of 4 - confirm against callers.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint quad = iqs / 4;      // index of the 4-element grid entry
    const uint group32 = iqs / 32;  // index of the 32-element group
    const uint lane = iqs % 8;      // position inside the 8-element group

    // Byte offset into qs of this group's packed sign/scale word
    // (the top 4 bits are the scale, 7-bit sign tuples sit below).
    const uint sign_byte = QUANT_K / 4 + 4 * group32;
    const uint packed = pack32(u16vec2(data_a_packed16[blk].qs[sign_byte / 2],
                                       data_a_packed16[blk].qs[sign_byte / 2 + 1]));
    const float db = 0.5 * (0.5 + (packed >> 28));

    const uint stored = bitfieldExtract(packed, 7 * (int(quad / 2) % 4), 7);
    // Add parity bit
    const uint with_parity = stored | (bitCount(stored) << 7);
    const uint sign = with_parity >> lane;

    const uint grid_idx = data_a[blk].qs[quad];
    const u8vec4 grid = unpack8(iq3xxs_grid[grid_idx]);

    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    const float sz = ((sign & 4) != 0) ? -1.0 : 1.0;
    const float sw = ((sign & 8) != 0) ? -1.0 : 1.0;
    return db * vec4(
        grid.x * sx,
        grid.y * sy,
        grid.z * sz,
        grid.w * sw
    );
}
#endif

#if defined(DATA_A_IQ3_S)
// IQ3_S: returns the two values at element positions iqs and iqs + 1 of
// block (a_offset + ib), with signs applied and multiplied by the sub-block
// scale factor (1 + 2 * nibble). The block's `d` field is not read here.
// NOTE(review): only the two lowest bits of the shifted sign byte are used,
// which presumably expects an even iqs - confirm against callers.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint quad = iqs / 4;      // index of the 4-element grid entry
    const uint group32 = iqs / 32;  // index of the 32-element group

    // One scale byte covers 64 elements; each 32-element half uses one nibble.
    const uint scale = data_a[blk].scales[iqs / 64];
    const float db = 1 + 2 * ((scale >> (4 * (group32 & 1))) & 0xf);

    // The grid index is 9 bits: 8 from qs, the ninth taken from qh.
    const uint low = data_a[blk].qs[quad];
    const uint high_src = data_a[blk].qh[group32];
    const uint grid_idx = low | ((high_src << (8 - (quad % 8))) & 256);
    const uint32_t grid = iq3s_grid[grid_idx] >> (8 * (iqs % 4));

    const uint sign = data_a[blk].signs[iqs / 8] >> (iqs % 8);
    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    return db * vec2(
        int(grid & 0xFF) * sx,
        int((grid >> 8) & 0xFF) * sy
    );
}
// IQ3_S: four-wide variant. Returns the four values starting at element
// position iqs of block (a_offset + ib), with signs applied and multiplied by
// the sub-block scale factor (1 + 2 * nibble). The block's `d` field is not
// read here.
// NOTE(review): presumably expects iqs to be a multiple of 4 - confirm
// against callers.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    const uint blk = a_offset + ib;
    const uint quad = iqs / 4;      // index of the 4-element grid entry
    const uint group32 = iqs / 32;  // index of the 32-element group

    // One scale byte covers two 32-element groups; each uses one nibble.
    const uint scale = data_a[blk].scales[group32 / 2];
    const float db = 1 + 2 * ((scale >> (4 * (group32 & 1))) & 0xf);

    // The grid index is 9 bits: 8 from qs, the ninth taken from qh.
    const uint low = data_a[blk].qs[quad];
    const uint high_src = data_a[blk].qh[group32];
    const uint grid_idx = low | ((high_src << (8 - quad % 8)) & 256);
    const uint32_t grid = iq3s_grid[grid_idx] >> (8 * (iqs % 4));

    const uint sign = data_a[blk].signs[iqs / 8] >> (iqs % 8);
    const float sx = ((sign & 1) != 0) ? -1.0 : 1.0;
    const float sy = ((sign & 2) != 0) ? -1.0 : 1.0;
    const float sz = ((sign & 4) != 0) ? -1.0 : 1.0;
    const float sw = ((sign & 8) != 0) ? -1.0 : 1.0;
    return db * vec4(
        int(grid & 0xFF) * sx,
        int((grid >> 8) & 0xFF) * sy,
        int((grid >> 16) & 0xFF) * sz,
        int((grid >> 24) & 0xFF) * sw
    );
}
#endif

#if defined(DATA_A_IQ4_NL)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
Expand All @@ -105,7 +321,7 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif

#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
// Returns (d, 0) for block (a_offset + ib): x is the block's `d` field
// converted to float, y is always 0.
vec2 get_dm(uint ib, uint a_offset) {
    const float d = float(data_a[a_offset + ib].d);
    return vec2(d, 0);
}
Expand Down
Loading
Loading