Skip to content

Commit 0326138

Browse files
jeffbolznv authored and mglambda committed
vulkan: optimize coopmat2 q4_k/q5_k dequant functions. (ggml-org#11206)
Do masking on whole dwords, fetch all scales at once.
1 parent 0c52585 commit 0326138

File tree

2 files changed

+57
-33
lines changed

2 files changed

+57
-33
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 47 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -163,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
163163
block_q4_K_packed16 block;
164164
};
165165

166+
// Buffer-reference view that dequantFuncQ4_K casts its block pointer to, so the
// block can be read through the uvec4 q4k[] member of block_q4_K_packed128.
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
167+
block_q4_K_packed128 block;
168+
};
169+
166170
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
167171
{
168172
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
173+
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
169174
const uint idx = coordInBlock[1];
170175

171176
const uint b = (idx & 0x20) >> 5; // 0,1
172177
const uint is = (idx & 0xE0) >> 5; // 0..7
173178

174-
const f16vec2 loadd = bl.block.d;
179+
uvec4 v = bl128.block.q4k[0];
180+
181+
const f16vec2 loadd = unpackFloat2x16(v.x);
175182

176183
uint32_t sc;
177184
uint32_t mbyte;
178185

179-
uint32_t scidx0 = (is < 4) ? is : (is + 4);
180-
uint32_t scidx1 = (is < 4) ? is : (is - 4);
181-
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
182-
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
183-
uint32_t mbidx0 = is + 4;
184-
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
185-
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
186-
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
187-
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
188-
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
186+
uint32_t scale0 = v.y;
187+
uint32_t scale4 = v.z;
188+
uint32_t scale8 = v.w;
189189

190-
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
191-
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
190+
uint32_t sc_lo = scale0;
191+
uint32_t mb_lo = scale4;
192+
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
193+
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
194+
195+
sc = is < 4 ? sc_lo : sc_hi;
196+
mbyte = is < 4 ? mb_lo : mb_hi;
197+
sc = sc >> (8 * (is & 3));
198+
mbyte = mbyte >> (8 * (is & 3));
199+
sc &= 0x3F;
200+
mbyte &= 0x3F;
192201

193202
const float16_t d = loadd.x * float16_t(sc);
194203
const float16_t m = loadd.y * float16_t(mbyte);
195204

196205
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
197-
qs = (qs >> (b * 4)) & 0x0F0F;
198-
qs = unpack8(qs)[idx & 1];
206+
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
199207

200208
float16_t ret = d * float16_t(qs) - m;
201209

@@ -210,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
210218
block_q5_K_packed16 block;
211219
};
212220

221+
// Buffer-reference view that dequantFuncQ5_K casts its block pointer to, so the
// block can be read through the uvec4 q5k[] member of block_q5_K_packed128.
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
222+
block_q5_K_packed128 block;
223+
};
224+
213225
// Dequantizes a single element of a Q5_K block.
//   bl            - buffer reference to the block being decoded
//   blockCoords   - not read by this function
//   coordInBlock  - coordInBlock[1] is the element index (idx) inside the block
// Returns d * q - m, where q is the 5-bit quant (4 bits from qs plus 1 bit from qh)
// and d / m are the f16 pair from the block header multiplied by the 6-bit
// scale / min selected by `is`.
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
214226
{
215227
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
228+
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); // same block viewed as uvec4 elements
216229
const uint idx = coordInBlock[1];
217230

218231
const uint b = (idx & 0x20) >> 5; // 0,1
219232
const uint is = (idx & 0xE0) >> 5; // 0..7
220233

221-
const uint32_t hm = 0x0101 << is;
234+
uvec4 v = bl128.block.q5k[0]; // first 16 bytes of the block: v.x = packed f16 pair, v.y/v.z/v.w = scale dwords
222235

223-
const f16vec2 loadd = bl.block.d;
236+
const f16vec2 loadd = unpackFloat2x16(v.x); // replaces the bl.block.d load
224237

225238
uint32_t sc;
226239
uint32_t mbyte;
227240

228-
uint32_t scidx0 = (is < 4) ? is : (is + 4);
229-
uint32_t scidx1 = (is < 4) ? is : (is - 4);
230-
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
231-
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
232-
uint32_t mbidx0 = is + 4;
233-
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
234-
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
235-
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
236-
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
237-
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
241+
uint32_t scale0 = v.y; // presumably scales[0..3] (assumes scales[] directly follows d; matches the removed indexing)
242+
uint32_t scale4 = v.z; // presumably scales[4..7]
243+
uint32_t scale8 = v.w; // presumably scales[8..11]
238244

239-
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
240-
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
245+
uint32_t sc_lo = scale0; // is < 4: scale is the low 6 bits of each byte (masked with 0x3F below)
246+
uint32_t mb_lo = scale4; // is < 4: min is the low 6 bits of each byte (masked with 0x3F below)
247+
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); // is >= 4: low nibble of scale8 bytes | top 2 bits of scale0 bytes moved to bits 4..5
248+
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); // is >= 4: high nibble of scale8 bytes | top 2 bits of scale4 bytes moved to bits 4..5
249+
250+
sc = is < 4 ? sc_lo : sc_hi;
251+
mbyte = is < 4 ? mb_lo : mb_hi;
252+
sc = sc >> (8 * (is & 3)); // bring the byte for this `is` down to bits 0..7
253+
mbyte = mbyte >> (8 * (is & 3));
254+
sc &= 0x3F; // keep the 6-bit value
255+
mbyte &= 0x3F;
241256

242257
const float16_t d = loadd.x * float16_t(sc);
243258
const float16_t m = loadd.y * float16_t(mbyte);
244259

245260
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
246-
qh = qh & hm;
247-
qh = unpack8(qh)[idx & 1];
261+
qh = ((qh >> is) & 0x101) << 4; // bit `is` of each qh byte becomes bit 4 of that byte
248262

249263
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
250264
qs = (qs >> (b * 4)) & 0x0F0F;
251-
qs = unpack8(qs)[idx & 1];
265+
qs = unpack8(qs | qh)[idx & 1]; // merge the high bit into each 4-bit value, then pick the byte for this element
252266

253-
float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
267+
float16_t ret = d * (float16_t(qs)) - m;
254268

255269
return ret;
256270
}

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,11 @@ struct block_q4_K_packed32
227227
uint32_t qs[QUANT_K_Q4_K/2/4];
228228
};
229229

230+
// View of a Q4_K block as an array of uvec4 elements.
// dequantFuncQ4_K reads q4k[0]: .x is unpacked as an f16 pair, .y/.z/.w are used as scale dwords.
struct block_q4_K_packed128
231+
{
232+
uvec4 q4k[9];
233+
};
234+
230235
#if defined(DATA_A_Q4_K)
231236
#define QUANT_K QUANT_K_Q4_K
232237
#define A_TYPE block_q4_K
@@ -252,6 +257,11 @@ struct block_q5_K_packed16
252257
uint16_t qs[QUANT_K_Q5_K/2/2];
253258
};
254259

260+
// View of a Q5_K block as an array of uvec4 elements.
// dequantFuncQ5_K reads q5k[0]: .x is unpacked as an f16 pair, .y/.z/.w are used as scale dwords.
struct block_q5_K_packed128
261+
{
262+
uvec4 q5k[11];
263+
};
264+
255265
#if defined(DATA_A_Q5_K)
256266
#define QUANT_K QUANT_K_Q5_K
257267
#define A_TYPE block_q5_K

0 commit comments

Comments
 (0)