Skip to content

Commit 078ebe5

Browse files
committed
port failing dequant callbacks from mul_mm
1 parent a56c535 commit 078ebe5

File tree

1 file changed

+50
-47
lines changed

1 file changed

+50
-47
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 50 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -368,23 +368,25 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2
368368

369369
// Decode callback for IQ2_S blocks: returns one dequantized float16_t element.
//   bl            - buffer reference to the block being decoded (fields d, scales, qs, qh are read).
//   blockCoords   - not referenced by either version of the body shown here.
//   coordInBlock  - only element [1] is read; it is used as the element index inside the block.
// NOTE(review): this span is a diff rendering, not a compilable body. Lines that follow an
// "N-" marker appear to be the removed version and lines that follow an "N+" marker the
// added version; two-number markers are unchanged context. TODO confirm against the real file.
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
370370
{
371-
const float16_t d = bl.block.d;
372-
const uint idx = coordInBlock[1];
// Added version: the low bit of the coordinate is kept in lsb and idx is halved, so each
// idx value now names a pair of adjacent elements; the final return picks v[lsb].
371+
uint idx = coordInBlock[1];
372+
uint lsb = idx & 1;
373+
idx /= 2;
373374

374-
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
375-
const uint ib8 = (idx & 0xF8) >> 3; // 0..31
376-
const uint qhshift = 2 * (ib8 % 4);
// Added version: idx is a pair index here, so four pairs (eight elements) share one ib8.
375+
const uint ib8 = (idx % 128) / 4; // 0..31
376+
const uint ib32 = ib8 / 4; // 0..7
377377

// Removed line: the shift amount (idx & 0x10) >> 4 evaluates to 0 or 1.
378-
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 4)) & 0xf;
// Added line: the shift amount 2 * (ib8 & 2) evaluates to 0 or 4, i.e. it selects the low
// or the high 4-bit half of scales[ib32] before the & 0xf mask.
378+
const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
379379
const uint qs = bl.block.qs[ib8];
380380
const uint qh = bl.block.qh[ib32];
381-
const uint sign = bl.block.qs[QUANT_K / 8 + ib8];
382-
383-
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(scale));
384-
const uint8_t g = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2])[idx & 3];
385-
386-
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
387-
return ret;
// qhshift is 0, 2, 4 or 6; (qh << (8 - qhshift)) & 0x300 moves the two qh bits starting at
// bit qhshift into bits 8..9, which are OR-ed onto qs to form the iq2s_grid index.
381+
const uint qhshift = 2 * (ib8 % 4);
// The sign byte is read from qs at offset QUANT_K / 8 + ib8 and shifted right so that
// bits 0 and 1 hold the sign bits of the current pair.
382+
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
383+
384+
const float d = float(bl.block.d);
385+
const float db = d * 0.25 * (0.5 + scale);
// sign01 = 1 - 2 * bit: component 0 uses bit 0 of sign (moved to bit 1 by << 1),
// component 1 uses bit 1, so each component is +1 or -1.
386+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
// (idx & 2) >> 1 picks one component of the grid entry and idx & 1 picks one 16-bit half
// of it; presumably each iq2s_grid entry is a two-component 32-bit vector - TODO confirm.
387+
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
// Both elements of the pair are scaled and signed together; lsb selects the one requested.
388+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
389+
return float16_t(v[lsb]);
388390
}
389391
#endif
390392

@@ -399,25 +401,28 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3
399401

400402
// Decode callback for IQ3_XXS blocks: returns one dequantized float16_t element.
//   bl            - buffer reference to the block being decoded (fields d and qs are read).
//   blockCoords   - not referenced by either version of the body shown here.
//   coordInBlock  - only element [1] is read; it is used as the element index inside the block.
// NOTE(review): this span is a diff rendering, not a compilable body. Lines that follow an
// "N-" marker appear to be the removed version and lines that follow an "N+" marker the
// added version; two-number markers are unchanged context. TODO confirm against the real file.
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
401403
{
402-
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
403-
const float16_t d = bl.block.d;
404-
const uint idx = coordInBlock[1];
405-
406-
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
// NOTE(review): on the removed line below, (idx & 0xFC) >> 4 cannot exceed 15, which does
// not match its trailing "0..63" comment; the added version indexes qs with iqs instead.
407-
const uint ib4 = (idx & 0xFC) >> 4; // 0..63
408-
const uint is16 = QUANT_K / 8 + 2 * ib32; // index in packed16
// Added version: the low bit of the coordinate is kept in lsb and idx is halved, so each
// idx value now names a pair of adjacent elements; the final return picks v[lsb].
404+
uint idx = coordInBlock[1];
405+
uint lsb = idx & 1;
406+
idx /= 2;
409407

410-
const uint8_t qs = bl.block.qs[ib4];
411-
const uint signscale = pack32(u16vec2(bl16.block.qs[is16], bl16.block.qs[is16+1]));
// iqs advances once per two pairs (four elements); is is the byte offset into qs of the
// 4-byte word that is packed into `signs` further down, one word per eight iqs values.
408+
const uint iqs = (idx % 128) / 2; // 0..63
409+
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
412410

413-
const float16_t dscale = bl.block.d * 0.5hf * (0.5hf + float16_t(signscale >> 28));
414-
uint sign = bitfieldExtract(signscale, 7 * int(ib4 & 3), 7);
415-
sign |= bitCount(sign) << 7;
416-
417-
const uint8_t g = unpack8(iq3xxs_grid[qs])[idx & 3];
418-
419-
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
420-
return ret;
411+
const float d = float(bl.block.d);
412+
const uint qs = bl.block.qs[iqs];
// Four consecutive bytes of qs starting at `is` are combined into one 32-bit word.
413+
const uint signs = pack32(u8vec4(
414+
bl.block.qs[is+0],
415+
bl.block.qs[is+1],
416+
bl.block.qs[is+2],
417+
bl.block.qs[is+3]
418+
));
// The top four bits of the packed word (signs >> 28) feed the scale factor.
419+
const float db = d * 0.5 * (0.5 + (signs >> 28));
// One of four 7-bit fields of the packed word is selected by (iqs / 2) % 4.
420+
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
// Bit 7 is filled with the low bit of bitCount(sign7) (the parity of the 7-bit field);
// the result is then shifted so bits 0 and 1 hold the sign bits of the current pair.
421+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
// sign01 = 1 - 2 * bit: component 0 uses bit 0 of sign (moved to bit 1 by << 1),
// component 1 uses bit 1, so each component is +1 or -1.
422+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
// idx & 1 chooses the low or high 16 bits of the grid word; .xy below takes the two
// lowest bytes after the shift.
423+
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
// Both elements of the pair are scaled and signed together; lsb selects the one requested.
424+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
425+
return float16_t(v[lsb]);
421426
}
422427
#endif
423428

@@ -428,26 +433,24 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3
428433

429434
// Decode callback for IQ3_S blocks: returns one dequantized float16_t element.
//   bl            - buffer reference to the block being decoded (fields d, qs, qh, signs, scales are read).
//   blockCoords   - not referenced by either version of the body shown here.
//   coordInBlock  - only element [1] is read; it is used as the element index inside the block.
// NOTE(review): this span is a diff rendering, not a compilable body. Lines that follow an
// "N-" marker appear to be the removed version and lines that follow an "N+" marker the
// added version; two-number markers are unchanged context. TODO confirm against the real file.
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
430435
{
431-
const float16_t d = bl.block.d;
432-
const uint idx = coordInBlock[1];
433-
434-
const uint iqs = (idx & 0xFC) >> 2; // 0..63
435-
const uint iqh = (idx & 0xE0) >> 5; // 0..7
436-
const uint qhbit = iqs & 7;
437-
const uint isgn = (idx & 0xF8) >> 3; // 0..31
438-
const uint is = (idx & 0xC0) >> 6; // 0..3
// Added version: the low bit of the coordinate is kept in lsb and idx is halved, so each
// idx value now names a pair of adjacent elements; the final return picks v[lsb].
436+
uint idx = coordInBlock[1];
437+
uint lsb = idx & 1;
438+
idx /= 2;
439439

440-
const uint8_t scale = (bl.block.scales[is] >> ((idx & 0x20) >> 3)) & uint8_t(0xF);
441-
const float16_t dscale = d * (1.0hf + float16_t(2 * scale));
// iqs advances once per two pairs (four elements); iqh advances once per eight iqs values.
440+
const uint iqs = (idx % 128) / 2; // 0..63
441+
const uint iqh = iqs / 8;
442442

443+
const float d = float(bl.block.d);
443444
const uint qs = bl.block.qs[iqs];
444-
const uint qh = (bl.block.qh[iqh] << (8 - qhbit)) & 0x100;
445-
const uint8_t sign = bl.block.signs[isgn];
446-
447-
const uint g = unpack8(iq3s_grid[qs | qh])[idx & 3];
448-
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
449-
450-
return ret;
445+
const uint qh = bl.block.qh[iqh];
// One signs byte covers two iqs values; the shift leaves the sign bits of the current
// pair in bits 0 and 1.
446+
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4)));
// One scales byte covers sixteen iqs values; the nibble is chosen later via iqh & 1.
447+
const uint scale = bl.block.scales[iqs / 16];
// sign01 = 1 - 2 * bit: component 0 uses bit 0 of sign (moved to bit 1 by << 1),
// component 1 uses bit 1, so each component is +1 or -1.
448+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
// 4 * (iqh & 1) is 0 or 4, selecting the low or high 4-bit half of the scale byte.
449+
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
// Bit (iqs % 8) of qh is moved to bit 8 and OR-ed onto qs to form the grid index;
// idx % 2 then chooses the low or high 16 bits of the grid word.
450+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
// Both elements of the pair are scaled and signed together; lsb selects the one requested.
451+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
452+
453+
return float16_t(v[lsb]);
451454
}
452455
#endif
453456

0 commit comments

Comments
 (0)