Skip to content

vulkan: optimize coopmat2 q4_k/q5_k dequant functions. #11206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 47 additions & 33 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -157,39 +157,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
block_q4_K_packed16 block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
block_q4_K_packed128 block;
};

float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
const uint idx = coordInBlock[1];

const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7

const f16vec2 loadd = bl.block.d;
uvec4 v = bl128.block.q4k[0];

const f16vec2 loadd = unpackFloat2x16(v.x);

uint32_t sc;
uint32_t mbyte;

uint32_t scidx0 = (is < 4) ? is : (is + 4);
uint32_t scidx1 = (is < 4) ? is : (is - 4);
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;

sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);

sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;

const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);

uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs)[idx & 1];
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;

float16_t ret = d * float16_t(qs) - m;

Expand All @@ -204,47 +212,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
block_q5_K_packed16 block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
block_q5_K_packed128 block;
};

float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
const uint idx = coordInBlock[1];

const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7

const uint32_t hm = 0x0101 << is;
uvec4 v = bl128.block.q5k[0];

const f16vec2 loadd = bl.block.d;
const f16vec2 loadd = unpackFloat2x16(v.x);

uint32_t sc;
uint32_t mbyte;

uint32_t scidx0 = (is < 4) ? is : (is + 4);
uint32_t scidx1 = (is < 4) ? is : (is - 4);
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;

sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);

sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;

const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);

uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = qh & hm;
qh = unpack8(qh)[idx & 1];
qh = ((qh >> is) & 0x101) << 4;

uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs)[idx & 1];
qs = unpack8(qs | qh)[idx & 1];

float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
float16_t ret = d * (float16_t(qs)) - m;

return ret;
}
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ struct block_q4_K_packed32
uint32_t qs[QUANT_K_Q4_K/2/4];
};

struct block_q4_K_packed128
{
uvec4 q4k[9];
};

#if defined(DATA_A_Q4_K)
#define QUANT_K QUANT_K_Q4_K
#define A_TYPE block_q4_K
Expand All @@ -252,6 +257,11 @@ struct block_q5_K_packed16
uint16_t qs[QUANT_K_Q5_K/2/2];
};

struct block_q5_K_packed128
{
uvec4 q5k[11];
};

#if defined(DATA_A_Q5_K)
#define QUANT_K QUANT_K_Q5_K
#define A_TYPE block_q5_K
Expand Down
Loading