Skip to content

Commit 26bcc7f

Browse files
vulkan: implement dequantize variants for coopmat2
1 parent a408b4b commit 26bcc7f

File tree

3 files changed

+140
-15
lines changed

3 files changed

+140
-15
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,10 +1627,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
16271627
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
16281628
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
16291629
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1630-
//CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1631-
//CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1632-
//CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1633-
//CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1630+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1631+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1632+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1633+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
16341634
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
16351635

16361636
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
@@ -1644,10 +1644,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
16441644
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
16451645
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
16461646
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1647-
//CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1648-
//CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1649-
//CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1650-
//CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1647+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1648+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1649+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1650+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
16511651
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
16521652
#undef CREATE_MM
16531653
#undef CREATE_MM2

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,130 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
301301
return ret;
302302
}
303303

304+
#if defined(DATA_A_IQ2_XXS)
305+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
306+
block_iq2_xxs block;
307+
};
308+
309+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
310+
block_iq2_xxs_packed16 block;
311+
};
312+
313+
float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
314+
{
315+
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
316+
const float16_t d = bl.block.d;
317+
const uint idx = coordInBlock[1];
318+
319+
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
320+
const uint ib8 = (idx & 0x18) >> 3; // 0..3
321+
const uint iqs = 8 * ib32 + ib8;
322+
323+
const uint8_t qs = bl.block.qs[iqs];
324+
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
325+
326+
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28));
327+
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
328+
sign |= bitCount(sign) << 7;
329+
330+
const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3];
331+
332+
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
333+
334+
return ret;
335+
}
336+
#endif
337+
338+
#if defined(DATA_A_IQ2_XS)
339+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
340+
block_iq2_xs block;
341+
};
342+
343+
float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
344+
{
345+
const float16_t d = bl.block.d;
346+
const uint idx = coordInBlock[1];
347+
348+
const uint is = (idx & 0xE0) >> 5; // 0..8
349+
const uint sshift = (idx & 0x10) >> 2; // 0,4
350+
const uint iqs = (idx & 0xF8) >> 3; // 0..63
351+
352+
const uint16_t qs = bl.block.qs[iqs];
353+
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF));
354+
355+
uint sign = uint(qs >> 9);
356+
sign |= bitCount(sign) << 7;
357+
const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3];
358+
359+
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
360+
return ret;
361+
}
362+
#endif
363+
364+
#if defined(DATA_A_IQ3_XXS)
365+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
366+
block_iq3_xxs block;
367+
};
368+
369+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
370+
block_iq3_xxs_packed16 block;
371+
};
372+
373+
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
374+
{
375+
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
376+
const float16_t d = bl.block.d;
377+
const uint idx = coordInBlock[1];
378+
379+
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
380+
const uint ib4 = (idx & 0xFC) >> 4; // 0..63
381+
const uint is16 = QUANT_K / 8 + 2 * ib32; // index in packed16
382+
383+
const uint8_t qs = bl.block.qs[ib4];
384+
const uint signscale = pack32(u16vec2(bl16.block.qs[is16], bl16.block.qs[is16+1]));
385+
386+
const float16_t dscale = bl.block.d * 0.5hf * (0.5hf + float16_t(signscale >> 28));
387+
uint sign = bitfieldExtract(signscale, 7 * int(ib4 & 3), 7);
388+
sign |= bitCount(sign) << 7;
389+
390+
const uint8_t g = unpack8(iq3xxs_grid[qs])[idx & 3];
391+
392+
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
393+
return ret;
394+
}
395+
#endif
396+
397+
#if defined(DATA_A_IQ3_S)
398+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
399+
block_iq3_s block;
400+
};
401+
402+
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
403+
{
404+
const float16_t d = bl.block.d;
405+
const uint idx = coordInBlock[1];
406+
407+
const uint iqs = (idx & 0xFC) >> 2; // 0..63
408+
const uint iqh = (idx & 0xE0) >> 5; // 0..7
409+
const uint qhbit = iqs & 7;
410+
const uint isgn = (idx & 0xF8) >> 3; // 0..31
411+
const uint is = (idx & 0xC0) >> 6; // 0..3
412+
413+
const uint8_t scale = (bl.block.scales[is] >> ((idx & 0x20) >> 3)) & uint8_t(0xF);
414+
const float16_t dscale = d * (1.0hf + float16_t(2 * scale));
415+
416+
const uint qs = bl.block.qs[iqs];
417+
const uint qh = (bl.block.qh[iqh] << (8 - qhbit)) & 0x100;
418+
const uint8_t sign = bl.block.signs[isgn];
419+
420+
const uint g = unpack8(iq3s_grid[qs | qh])[idx & 3];
421+
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
422+
423+
return ret;
424+
}
425+
#endif
426+
427+
304428
#if defined(DATA_A_IQ4_NL)
305429
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
306430
block_iq4_nl block;
@@ -340,6 +464,14 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
340464
#define dequantFuncA dequantFuncQ5_K
341465
#elif defined(DATA_A_Q6_K)
342466
#define dequantFuncA dequantFuncQ6_K
467+
#elif defined(DATA_A_IQ2_XXS)
468+
#define dequantFuncA dequantFuncIQ2_XXS
469+
#elif defined(DATA_A_IQ2_XS)
470+
#define dequantFuncA dequantFuncIQ2_XS
471+
#elif defined(DATA_A_IQ3_XXS)
472+
#define dequantFuncA dequantFuncIQ3_XXS
473+
#elif defined(DATA_A_IQ3_S)
474+
#define dequantFuncA dequantFuncIQ3_S
343475
#elif defined(DATA_A_IQ4_NL)
344476
#define dequantFuncA dequantFuncIQ4_NL
345477
#endif

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
314314
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
315315

316316
for (const auto& tname : type_names) {
317-
if (tname == "iq2_xs" && coopmat2) continue;
318-
if (tname == "iq2_xxs" && coopmat2) continue;
319-
if (tname == "iq3_xxs" && coopmat2) continue;
320-
if (tname == "iq3_s" && coopmat2) continue;
321317
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
322318
// For unaligned, load one at a time for f32/f16, or two at a time for quants
323319
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
@@ -373,9 +369,6 @@ void process_shaders() {
373369
if (tname == "f32") {
374370
continue;
375371
}
376-
if (tname == "iq3_s" || tname == "iq3_xxs" || tname == "iq2_xs" || tname == "iq2_xxs") {
377-
continue;
378-
}
379372

380373
if (tname == "f16") {
381374
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",

0 commit comments

Comments
 (0)