@@ -447,7 +447,17 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
447
447
// we assume that the yl's have been multiplied with the appropriate scale factor
448
448
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
449
449
inline float block_q_n_dot_y (device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
450
- // TODO
450
+ float d = qb_curr->d ;
451
+ float2 acc = 0 .f ;
452
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2 );
453
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh );
454
+ for (int i = 0 ; i < 8 ; i+=2 ) {
455
+ acc[0 ] += yl[i + 0 ] * ((qs[i / 2 ] & 0x000F ) | ((qh >> (i+0 +il ) << 4 ) & 0x0010 ))
456
+ + yl[i + 1 ] * ((qs[i / 2 ] & 0x0F00 ) | ((qh >> (i+1 +il ) << 12 ) & 0x1000 ));
457
+ acc[1 ] += yl[i + 8 ] * ((qs[i / 2 ] & 0x00F0 ) | ((qh >> (i+0 +il+QK5_0/2 ) << 4 ) & 0x0010 ))
458
+ + yl[i + 9 ] * ((qs[i / 2 ] & 0xF000 ) | ((qh >> (i+1 +il+QK5_0/2 ) << 12 ) & 0x1000 ));
459
+ }
460
+ return d * (sumy * -16 .f + acc[0 ] + acc[1 ]);
451
461
}
452
462
453
463
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -2225,13 +2235,13 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
2225
2235
2226
2236
const int x_mv = (il ? 4 : 0 );
2227
2237
2228
- const int gh_mv = (il ? 12 : 0 );
2229
- const int gh_bk = (il ? 0 : 4 );
2238
+ const int qh_mv = (il ? 12 : 0 );
2239
+ const int qh_bk = (il ? 0 : 4 );
2230
2240
2231
2241
for (int i = 0 ; i < 8 ; i++) {
2232
2242
// extract the 5-th bits for x0 and x1
2233
- const uint8_t xh_0 = ((qh >> (gh_mv + 2 *i )) << gh_bk ) & 0x10 ;
2234
- const uint8_t xh_1 = ((qh >> (gh_mv + 2 *i+1 )) << gh_bk ) & 0x10 ;
2243
+ const uint8_t xh_0 = ((qh >> (qh_mv + 2 *i )) << qh_bk ) & 0x10 ;
2244
+ const uint8_t xh_1 = ((qh >> (qh_mv + 2 *i+1 )) << qh_bk ) & 0x10 ;
2235
2245
2236
2246
// combine the 4-bits from qs with the 5th bit
2237
2247
const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0);
0 commit comments