@@ -877,7 +877,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
877
877
static const int qk = QK4_0 ;
878
878
879
879
assert (qk / 16 == 0 );
880
- assert (k % qk == 0 );
880
+ assert ( k % qk == 0 );
881
881
882
882
const int nb = k / qk ;
883
883
@@ -912,7 +912,7 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
912
912
const int qk = QK4_1 ;
913
913
914
914
assert (qk / 16 == 0 );
915
- assert (k % qk == 0 );
915
+ assert ( k % qk == 0 );
916
916
917
917
const int nb = k / qk ;
918
918
@@ -945,48 +945,37 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k
945
945
946
946
// reference implementation for deterministic creation of model files
947
947
// Reference (scalar) Q4_2 quantizer, used for deterministic creation of model files.
//
//   x : k input floats
//   y : output array of k / QK4_2 blocks
//   k : number of input values; must be a multiple of QK4_2
//
// For each block the element with the largest magnitude is located (keeping its
// sign), the scale is chosen as d = max / -8, and d is stored in the block as
// fp16. The packing of the scaled values into y[i].qs is delegated to
// nibbles_from_floats_64_0().
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
    static const int qk = QK4_2;

    // The block size must be a whole number of 16-value groups (the scratch
    // buffer below is sized as QK4_2 / 16 words). The previous check used
    // `qk / 16 == 0`, which can only hold for qk < 16 and therefore fired for
    // every valid block size.
    assert(qk % 16 == 0);
    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max
        float max  = 0.0f; // signed value at which the absolute max was found

        for (int l = 0; l < qk; l++) {
            const float v = x[i*qk + l];
            if (amax < fabsf(v)) {
                amax = fabsf(v);
                max  = v;
            }
        }

        const float d  = max / -8;
        const float id = d ? 1.0f/d : 0.0f; // all-zero block: avoid dividing by zero

        y[i].d = GGML_FP32_TO_FP16(d);

        // Zeroed scratch words handed to the packing helper.
        // NOTE(review): presumably the helper accumulates the 4-bit values here
        // before storing them to y[i].qs - confirm against its definition.
        uint64_t qs[QK4_2 / 16] = {0};

        nibbles_from_floats_64_0(qk, x + i*qk, id, y[i].qs, qs);
    }
}
984
977
985
- static void quantize_row_q4_2 (const float * restrict x , void * restrict vy , int k ) {
986
- assert (k % QK4_2 == 0 );
987
-
988
- block_q4_2 * restrict y = vy ;
989
-
978
+ static void quantize_row_q4_2 (const float * restrict x , void * restrict y , int k ) {
990
979
quantize_row_q4_2_reference (x , y , k );
991
980
}
992
981
@@ -1324,7 +1313,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
1324
1313
static const int qk = QK4_0 ;
1325
1314
1326
1315
assert (qk / 16 == 0 );
1327
- assert (k % qk == 0 );
1316
+ assert ( k % qk == 0 );
1328
1317
1329
1318
const int nb = k / qk ;
1330
1319
@@ -1345,7 +1334,7 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
1345
1334
static const int qk = QK4_1 ;
1346
1335
1347
1336
assert (qk / 16 == 0 );
1348
- assert (k % qk == 0 );
1337
+ assert ( k % qk == 0 );
1349
1338
1350
1339
const int nb = k / qk ;
1351
1340
@@ -1363,31 +1352,23 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
1363
1352
}
1364
1353
}
1365
1354
1366
- static void dequantize_row_q4_2 (const void * restrict vx , float * restrict y , int k ) {
1367
- assert (k % QK4_2 == 0 );
1368
- const int nb = k / QK4_2 ;
1369
-
1370
- const block_q4_2 * restrict x = vx ;
1371
-
1372
- for (int i = 0 ; i < nb ; i ++ ) {
1373
- const float d = GGML_FP16_TO_FP32 (x [i ].d );
1355
+ static void dequantize_row_q4_2 (const block_q4_2 * restrict x , float * restrict y , int k ) {
1356
+ static const int qk = QK4_2 ;
1374
1357
1375
- const uint8_t * restrict pp = x [i ].qs ;
1358
+ assert (qk / 16 == 0 );
1359
+ assert ( k % qk == 0 );
1376
1360
1377
- for (int l = 0 ; l < QK4_2 ; l += 2 ) {
1378
- const uint8_t vi = pp [l /2 ];
1361
+ const int nb = k / qk ;
1379
1362
1380
- const int8_t vi0 = vi & 0x0F ;
1381
- const int8_t vi1 = vi >> 4 ;
1363
+ uint64_t qs [QK4_2 / 8 ];
1382
1364
1383
- const float v0 = ( vi0 - 8 ) * d ;
1384
- const float v1 = ( vi1 - 8 ) * d ;
1365
+ for ( int i = 0 ; i < nb ; i ++ ) {
1366
+ const float d = GGML_FP16_TO_FP32 ( x [ i ]. d ) ;
1385
1367
1386
- y [i * QK4_2 + l + 0 ] = v0 ;
1387
- y [i * QK4_2 + l + 1 ] = v1 ;
1368
+ const uint8_t * qsp = bytes_from_nibbles_64 (qk , x [i ].qs , qs );
1388
1369
1389
- assert (! isnan ( y [ i * QK4_2 + l + 0 ]));
1390
- assert (! isnan ( y [i * QK4_2 + l + 1 ])) ;
1370
+ for ( int l = 0 ; l < qk ; ++ l ) {
1371
+ y [i * qk + l ] = ( qsp [ l ] - 8 ) * d ;
1391
1372
}
1392
1373
}
1393
1374
}
@@ -1507,7 +1488,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1507
1488
.vec_dot_type = GGML_TYPE_Q8_1 ,
1508
1489
},
1509
1490
[GGML_TYPE_Q4_2 ] = {
1510
- .dequantize_row_q = dequantize_row_q4_2 ,
1491
+ .dequantize_row_q = ( dequantize_row_q_t ) dequantize_row_q4_2 ,
1511
1492
.quantize_row_q = quantize_row_q4_2 ,
1512
1493
.quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q4_2_reference ,
1513
1494
.quantize_row_q_dot = quantize_row_q8_0 ,
@@ -2432,11 +2413,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2432
2413
}
2433
2414
2434
2415
static void ggml_vec_dot_q4_2_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2435
- const int nb = n / QK8_0 ;
2416
+ const int qk = QK8_0 ;
2417
+ const int nb = n / qk ;
2436
2418
2437
- assert (n % QK8_0 == 0 );
2419
+ assert (n % qk == 0 );
2438
2420
assert (nb % 2 == 0 );
2439
- assert (QK8_0 == 2 * QK4_2 );
2421
+
2422
+ assert (qk == 2 * QK4_2 );
2440
2423
2441
2424
const block_q4_2 * restrict x = vx ;
2442
2425
const block_q8_0 * restrict y = vy ;
@@ -2472,12 +2455,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2472
2455
const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
2473
2456
const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
2474
2457
2475
- // interleave
2476
- const int8x16_t v0_0lz = vzip1q_s8 (v0_0ls , v0_0hs );
2477
- const int8x16_t v0_0hz = vzip2q_s8 (v0_0ls , v0_0hs );
2478
- const int8x16_t v0_1lz = vzip1q_s8 (v0_1ls , v0_1hs );
2479
- const int8x16_t v0_1hz = vzip2q_s8 (v0_1ls , v0_1hs );
2480
-
2481
2458
// load y
2482
2459
const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
2483
2460
const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
@@ -2486,22 +2463,22 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2486
2463
2487
2464
#if defined(__ARM_FEATURE_DOTPROD )
2488
2465
sumv0 = vmlaq_n_f32 (sumv0 , vaddq_f32 (
2489
- vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l )), GGML_FP16_TO_FP32 (x0_0 -> d )),
2490
- vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), GGML_FP16_TO_FP32 (x0_1 -> d ))), y0 -> d );
2466
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0l )), GGML_FP16_TO_FP32 (x0_0 -> d )),
2467
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hs , v1_0h )), GGML_FP16_TO_FP32 (x0_1 -> d ))), y0 -> d );
2491
2468
2492
2469
sumv1 = vmlaq_n_f32 (sumv1 , vaddq_f32 (
2493
- vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l )), GGML_FP16_TO_FP32 (x1_0 -> d )),
2494
- vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h )), GGML_FP16_TO_FP32 (x1_1 -> d ))), y1 -> d );
2470
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1l )), GGML_FP16_TO_FP32 (x1_0 -> d )),
2471
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hs , v1_1h )), GGML_FP16_TO_FP32 (x1_1 -> d ))), y1 -> d );
2495
2472
#else
2496
- const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
2497
- const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
2498
- const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hz ), vget_low_s8 (v1_0h ));
2499
- const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hz ), vget_high_s8 (v1_0h ));
2500
-
2501
- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1lz ), vget_low_s8 (v1_1l ));
2502
- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1lz ), vget_high_s8 (v1_1l ));
2503
- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hz ), vget_low_s8 (v1_1h ));
2504
- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hz ), vget_high_s8 (v1_1h ));
2473
+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0l ));
2474
+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0l ));
2475
+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0h ));
2476
+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0h ));
2477
+
2478
+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1l ));
2479
+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1l ));
2480
+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1h ));
2481
+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1h ));
2505
2482
2506
2483
const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
2507
2484
const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
0 commit comments