@@ -163,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
163
163
block_q4_K_packed16 block;
164
164
};
165
165
166
// Buffer-reference type (std430, 16-byte reference alignment) wrapping a single
// block_q4_K_packed128 member. dequantFuncQ4_K constructs it from its decodeBufQ4_K
// argument and reads block.q4k[0] as one uvec4.
// NOTE(review): block_q4_K_packed128 is declared outside this view - confirm its layout there.
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
167
+ block_q4_K_packed128 block;
168
+ };
169
+
166
170
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
167
171
{
168
172
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
173
+ decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
169
174
const uint idx = coordInBlock[1];
170
175
171
176
const uint b = (idx & 0x20) >> 5; // 0,1
172
177
const uint is = (idx & 0xE0) >> 5; // 0..7
173
178
174
- const f16vec2 loadd = bl.block.d;
179
+ uvec4 v = bl128.block.q4k[0];
180
+
181
+ const f16vec2 loadd = unpackFloat2x16(v.x);
175
182
176
183
uint32_t sc;
177
184
uint32_t mbyte;
178
185
179
- uint32_t scidx0 = (is < 4) ? is : (is + 4);
180
- uint32_t scidx1 = (is < 4) ? is : (is - 4);
181
- uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
182
- uint32_t scidxshift1 = (is < 4) ? 0 : 2;
183
- uint32_t mbidx0 = is + 4;
184
- uint32_t mbidx1 = (is < 4) ? is + 4 : is;
185
- uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
186
- uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
187
- uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
188
- uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
186
+ uint32_t scale0 = v.y;
187
+ uint32_t scale4 = v.z;
188
+ uint32_t scale8 = v.w;
189
189
190
- sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
191
- mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
190
+ uint32_t sc_lo = scale0;
191
+ uint32_t mb_lo = scale4;
192
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
193
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
194
+
195
+ sc = is < 4 ? sc_lo : sc_hi;
196
+ mbyte = is < 4 ? mb_lo : mb_hi;
197
+ sc = sc >> (8 * (is & 3));
198
+ mbyte = mbyte >> (8 * (is & 3));
199
+ sc &= 0x3F;
200
+ mbyte &= 0x3F;
192
201
193
202
const float16_t d = loadd.x * float16_t(sc);
194
203
const float16_t m = loadd.y * float16_t(mbyte);
195
204
196
205
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
197
- qs = (qs >> (b * 4)) & 0x0F0F;
198
- qs = unpack8(qs)[idx & 1];
206
+ qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
199
207
200
208
float16_t ret = d * float16_t(qs) - m;
201
209
@@ -210,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
210
218
block_q5_K_packed16 block;
211
219
};
212
220
221
// Buffer-reference type (std430, 16-byte reference alignment) wrapping a single
// block_q5_K_packed128 member. dequantFuncQ5_K constructs it from its decodeBufQ5_K
// argument and reads block.q5k[0] as one uvec4.
// NOTE(review): block_q5_K_packed128 is declared outside this view - confirm its layout there.
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
222
+ block_q5_K_packed128 block;
223
+ };
224
+
213
225
// Dequantizes one element of a Q5_K block and returns it as float16_t.
//
// bl            - buffer reference to the block being decoded.
// blockCoords   - not read by this function.
// coordInBlock  - only element [1] is read; it is the element index (idx) inside the block.
//
// Result: d * q - m, where d and m are the two block-level halves each multiplied by a
// 6-bit per-sub-block value (sc / mbyte), and q is the 4-bit value from qs plus 16 when
// the matching bit in qh is set.
//
// This span is a diff: lines marked "-" are the replaced implementation, lines marked
// "+" are the new one.
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
214
226
{
215
227
// The same reference is re-typed twice: bl16 is used for the qh/qs reads below, bl128
// for a single uvec4 read of the block header.
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
228
+ decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
216
229
const uint idx = coordInBlock[1];
217
230
218
231
const uint b = (idx & 0x20) >> 5; // bit 5 of idx: 0 or 1, selects which nibble of each qs byte is used
219
232
const uint is = (idx & 0xE0) >> 5; // bits 5..7 of idx: 0..7, selects the scale/min pair and the qh bit
220
233
221
// Replaced path: per-byte mask for bit `is` of qh.
- const uint32_t hm = 0x0101 << is ;
234
// New path: one uvec4 load; v.x is reinterpreted as the two halves, v.y/v.z/v.w feed the scales.
// NOTE(review): presumably v.y..v.w are the 12 bytes the replaced code read as
// bl.block.scales[0..11] - confirm against the block_q5_K_packed128 declaration.
+ uvec4 v = bl128.block.q5k[0] ;
222
235
223
- const f16vec2 loadd = bl.block.d ;
236
+ const f16vec2 loadd = unpackFloat2x16(v.x) ;
224
237
225
238
uint32_t sc;
226
239
uint32_t mbyte;
227
240
// Replaced path: index/mask/shift values chosen per `is` so that sc and mbyte could be
// assembled from individual scales[] bytes.
228
- uint32_t scidx0 = (is < 4) ? is : (is + 4);
229
- uint32_t scidx1 = (is < 4) ? is : (is - 4);
230
- uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
231
- uint32_t scidxshift1 = (is < 4) ? 0 : 2;
232
- uint32_t mbidx0 = is + 4;
233
- uint32_t mbidx1 = (is < 4) ? is + 4 : is;
234
- uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
235
- uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
236
- uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
237
- uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
241
+ uint32_t scale0 = v.y;
242
+ uint32_t scale4 = v.z;
243
+ uint32_t scale8 = v.w;
238
244
239
- sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
240
- mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
245
// New path: build all four candidates a whole 32-bit word at a time (four bytes in parallel).
// *_lo (is < 4): the word is used as-is; the final 0x3F mask keeps the low 6 bits of the byte.
// *_hi (is >= 4): low/high nibble of each scale8 byte, with the top two bits of the
// matching scale0/scale4 byte moved down to bits 4..5.
+ uint32_t sc_lo = scale0;
246
+ uint32_t mb_lo = scale4;
247
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
248
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
249
+
250
// Pick the lo/hi word, shift byte number (is & 3) down to bits 0..7, keep 6 bits.
+ sc = is < 4 ? sc_lo : sc_hi;
251
+ mbyte = is < 4 ? mb_lo : mb_hi;
252
+ sc = sc >> (8 * (is & 3));
253
+ mbyte = mbyte >> (8 * (is & 3));
254
+ sc &= 0x3F;
255
+ mbyte &= 0x3F;
241
256
242
257
const float16_t d = loadd.x * float16_t(sc);
243
258
const float16_t m = loadd.y * float16_t(mbyte);
244
259
245
260
// qh entry is indexed by bits 1..4 of idx.
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
246
// Replaced path: mask with hm, select byte (idx & 1), and test for non-zero in the return expression.
- qh = qh & hm;
247
- qh = unpack8(qh)[idx & 1];
261
// New path: move bit `is` of each of the two low bytes to bit 4 (value 16) of that byte,
// so it can be OR-ed straight onto the 4-bit value below.
+ qh = ((qh >> is) & 0x101) << 4;
248
262
249
263
// qs entry index combines bits 6..7 and bits 1..4 of idx; after the shift and mask each
// of the two low bytes holds one 4-bit value.
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
250
264
qs = (qs >> (b * 4)) & 0x0F0F;
251
- qs = unpack8(qs)[idx & 1];
265
// Byte (idx & 1) of (qs | qh) is the 4-bit value plus 16 when the qh bit was set.
+ qs = unpack8(qs | qh )[idx & 1];
252
266
253
- float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0)) ) - m;
267
+ float16_t ret = d * (float16_t(qs)) - m;
254
268
255
269
return ret;
256
270
}
0 commit comments