Skip to content

vulkan: optimize iq1 coopmat2 dequant functions #12427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];

const uint ib32 = idx / 32;
const uint ib8 = idx / 8;
const uint ib32 = (idx & 0xE0) >> 5;
const uint ib8 = (idx & 0xF8) >> 3;

const uint qh = bl.block.qh[ib32];
const uint qs = bl.block.qs[ib8];
Expand All @@ -330,14 +330,20 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1
block_iq1_m block;
};

// Buffer-reference type exposing a single block_iq1_m_packed64 as `block`,
// declared with std430 layout and an 8-byte reference alignment.
// dequantFuncIQ1_M constructs one of these from its decodeBufIQ1_M argument
// so that the block's scales can be read as one 64-bit value.
// NOTE(review): the neighbouring non-packed buffer reference is declared with
// buffer_reference_align = 2; converting to this type presumably relies on
// blocks actually being 8-byte aligned in memory — confirm with the caller.
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
block_iq1_m_packed64 block;
};

float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
const uint idx = coordInBlock[1];

const uint ib8 = idx / 8;
const uint ib16 = idx / 16;
uvec2 scales = unpack32(bl64.block.scales);
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));

const uint ib8 = (idx & 0xF8) >> 3;
const uint ib16 = (idx & 0xF0) >> 4;
const int i8 = int(idx % 8);
const uint sc = bl.block.scales[ib8 / 8];
const uint qs = bl.block.qs[ib8];
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP

#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
Expand Down Expand Up @@ -312,6 +313,12 @@ struct block_iq1_m {
uint16_t scales[QUANT_K_IQ1_M/64];
};

// IQ1_M block described entirely in 64-bit words: every member is uint64_t.
// NOTE(review): presumably an overlay of the same bytes as block_iq1_m with
// each field regrouped into uint64_t units — confirm the two structs have
// identical size and member offsets.
struct block_iq1_m_packed64 {
// QUANT_K_IQ1_M/8/8 words.
uint64_t qs[QUANT_K_IQ1_M/8/8];
// QUANT_K_IQ1_M/16/8 words.
uint64_t qh[QUANT_K_IQ1_M/16/8];
// Single word; appears to cover the whole uint16_t scales[] array of
// block_iq1_m so it can be fetched with one load — TODO confirm.
uint64_t scales;
};

#if defined(DATA_A_IQ1_S)
#define QUANT_K QUANT_K_IQ1_S
#define QUANT_R QUANT_R_IQ1_S
Expand Down
Loading