@@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
656
656
#define QK8_0 32
657
657
typedef struct {
658
658
float d ; // delta
659
- float s ; // d * sum(qs[i])
659
+ float s0 ; // d * sum(qs[i]) low
660
+ float s1 ; // d * sum(qs[i]) high
660
661
int8_t qs [QK8_0 ]; // quants
661
662
} block_q8_0 ;
662
- static_assert (sizeof (block_q8_0 ) == 2 * sizeof (float ) + QK8_0 , "wrong q8_0 block size/padding" );
663
+ static_assert (sizeof (block_q8_0 ) == 3 * sizeof (float ) + QK8_0 , "wrong q8_0 block size/padding" );
663
664
664
665
665
666
// reference implementation for deterministic creation of model files
@@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
1299
1300
1300
1301
y [i ].d = d ;
1301
1302
1302
- int sum = 0 ;
1303
- for (int l = 0 ; l < QK8_0 ; ++ l ) {
1304
- const float v = x [i * QK8_0 + l ]* id ;
1305
- y [i ].qs [l ] = roundf (v );
1306
- sum += y [i ].qs [l ];
1303
+ int sum0 = 0 ;
1304
+ int sum1 = 0 ;
1305
+
1306
+ for (int l = 0 ; l < QK8_0 /2 ; ++ l ) {
1307
+ const float v0 = x [i * QK8_0 + l ]* id ;
1308
+ const float v1 = x [i * QK8_0 + QK8_0 /2 + l ]* id ;
1309
+
1310
+ y [i ].qs [ l ] = roundf (v0 );
1311
+ y [i ].qs [QK8_0 /2 + l ] = roundf (v1 );
1312
+
1313
+ sum0 += y [i ].qs [ l ];
1314
+ sum1 += y [i ].qs [QK8_0 /2 + l ];
1307
1315
}
1308
- y [i ].s = d * sum ;
1316
+
1317
+ y [i ].s0 = d * sum0 ;
1318
+ y [i ].s1 = d * sum1 ;
1309
1319
}
1310
1320
}
1311
1321
@@ -1335,9 +1345,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1335
1345
1336
1346
y [i ].d = d ;
1337
1347
1338
- int32x4_t accv = vdupq_n_s32 (0 );
1348
+ int32x4_t accv0 = vdupq_n_s32 (0 );
1349
+ int32x4_t accv1 = vdupq_n_s32 (0 );
1339
1350
1340
- for (int l = 0 ; l < 8 ; l ++ ) {
1351
+ // low half
1352
+ for (int l = 0 ; l < 4 ; l ++ ) {
1353
+ const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
1354
+ const int32x4_t vi = vcvtnq_s32_f32 (v );
1355
+
1356
+ y [i ].qs [4 * l + 0 ] = vgetq_lane_s32 (vi , 0 );
1357
+ y [i ].qs [4 * l + 1 ] = vgetq_lane_s32 (vi , 1 );
1358
+ y [i ].qs [4 * l + 2 ] = vgetq_lane_s32 (vi , 2 );
1359
+ y [i ].qs [4 * l + 3 ] = vgetq_lane_s32 (vi , 3 );
1360
+
1361
+ accv0 = vaddq_s32 (accv0 , vi );
1362
+ }
1363
+
1364
+ // high half
1365
+ for (int l = 4 ; l < 8 ; l ++ ) {
1341
1366
const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
1342
1367
const int32x4_t vi = vcvtnq_s32_f32 (v );
1343
1368
@@ -1346,12 +1371,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1346
1371
y [i ].qs [4 * l + 2 ] = vgetq_lane_s32 (vi , 2 );
1347
1372
y [i ].qs [4 * l + 3 ] = vgetq_lane_s32 (vi , 3 );
1348
1373
1349
- accv = vaddq_s32 (accv , vi );
1374
+ accv1 = vaddq_s32 (accv1 , vi );
1350
1375
}
1351
- int32_t sum = vaddvq_s32 (accv );
1352
- y [i ].s = d * sum ;
1376
+
1377
+ const int32_t sum0 = vaddvq_s32 (accv0 );
1378
+ const int32_t sum1 = vaddvq_s32 (accv1 );
1379
+
1380
+ y [i ].s0 = d * sum0 ;
1381
+ y [i ].s1 = d * sum1 ;
1353
1382
}
1354
1383
#elif defined(__AVX2__ ) || defined(__AVX__ )
1384
+ // TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
1355
1385
for (int i = 0 ; i < nb ; i ++ ) {
1356
1386
// Load elements into 4 AVX vectors
1357
1387
__m256 v0 = _mm256_loadu_ps ( x );
@@ -2395,7 +2425,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2395
2425
const block_q8_0 * restrict y0 = & y [i + 0 ];
2396
2426
const block_q8_0 * restrict y1 = & y [i + 1 ];
2397
2427
2398
- sum8 += x0 -> d * y0 -> s + x1 -> d * y1 -> s ;
2428
+ sum8 += x0 -> d * ( y0 -> s0 + y0 -> s1 ) + x1 -> d * ( y1 -> s0 + y1 -> s1 ) ;
2399
2429
2400
2430
const uint8x16_t m4b = vdupq_n_u8 (0xf );
2401
2431
@@ -2562,7 +2592,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2562
2592
const block_q8_0 * restrict y0 = & y [i + 0 ];
2563
2593
const block_q8_0 * restrict y1 = & y [i + 1 ];
2564
2594
2565
- summs += x0 -> m * y0 -> s + x1 -> m * y1 -> s ;
2595
+ summs += x0 -> m * ( y0 -> s0 + y0 -> s1 ) + x1 -> m * ( y1 -> s0 + y1 -> s1 ) ;
2566
2596
2567
2597
const uint8x16_t m4b = vdupq_n_u8 (0xf );
2568
2598
@@ -2589,8 +2619,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2589
2619
2590
2620
#if defined(__ARM_FEATURE_DOTPROD )
2591
2621
// dot product into int32x4_t
2592
- const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0l , v1_0ls ), v0_0h , v1_0hs );
2593
- const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1l , v1_1ls ), v0_1h , v1_1hs );
2622
+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l ), v0_0hz , v1_0h );
2623
+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l ), v0_1hz , v1_1h );
2594
2624
2595
2625
sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), x0 -> d * y0 -> d );
2596
2626
sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), x1 -> d * y1 -> d );
@@ -2845,6 +2875,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
2845
2875
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
2846
2876
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
2847
2877
2878
+ float summs = 0.0f ;
2879
+
2848
2880
for (int i = 0 ; i < nb ; i += 2 ) {
2849
2881
const block_q4_3 * restrict x0_0 = & x [2 * (i + 0 ) + 0 ];
2850
2882
const block_q4_3 * restrict x0_1 = & x [2 * (i + 0 ) + 1 ];
@@ -2854,18 +2886,16 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
2854
2886
const block_q8_0 * restrict y0 = & y [i + 0 ];
2855
2887
const block_q8_0 * restrict y1 = & y [i + 1 ];
2856
2888
2889
+ summs += GGML_FP16_TO_FP32 (x0_0 -> m ) * y0 -> s0 + GGML_FP16_TO_FP32 (x0_1 -> m ) * y0 -> s1 ;
2890
+ summs += GGML_FP16_TO_FP32 (x1_0 -> m ) * y1 -> s0 + GGML_FP16_TO_FP32 (x1_1 -> m ) * y1 -> s1 ;
2891
+
2857
2892
const uint8x16_t m4b = vdupq_n_u8 (0xf );
2858
2893
2859
2894
const float x0_0d = GGML_FP16_TO_FP32 (x0_0 -> d );
2860
2895
const float x0_1d = GGML_FP16_TO_FP32 (x0_1 -> d );
2861
2896
const float x1_0d = GGML_FP16_TO_FP32 (x1_0 -> d );
2862
2897
const float x1_1d = GGML_FP16_TO_FP32 (x1_1 -> d );
2863
2898
2864
- const float x0_0m = GGML_FP16_TO_FP32 (x0_0 -> m );
2865
- const float x0_1m = GGML_FP16_TO_FP32 (x0_1 -> m );
2866
- const float x1_0m = GGML_FP16_TO_FP32 (x1_0 -> m );
2867
- const float x1_1m = GGML_FP16_TO_FP32 (x1_1 -> m );
2868
-
2869
2899
const uint8x16_t v0_0 = vcombine_u8 (vld1_u8 (x0_0 -> qs ), vld1_u8 (x0_1 -> qs ));
2870
2900
const uint8x16_t v0_1 = vcombine_u8 (vld1_u8 (x1_0 -> qs ), vld1_u8 (x1_1 -> qs ));
2871
2901
@@ -2887,17 +2917,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
2887
2917
const int8x16_t v1_1l = vld1q_s8 (y1 -> qs );
2888
2918
const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
2889
2919
2890
- const int16x8_t sy0_0 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_0l )), vmovl_s8 (vget_high_s8 (v1_0l )));
2891
- const int16x8_t sy0_1 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_0h )), vmovl_s8 (vget_high_s8 (v1_0h )));
2892
-
2893
- const int16x8_t sy1_0 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_1l )), vmovl_s8 (vget_high_s8 (v1_1l )));
2894
- const int16x8_t sy1_1 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_1h )), vmovl_s8 (vget_high_s8 (v1_1h )));
2895
-
2896
- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy0_0 ), vget_high_s16 (sy0_0 ))), x0_0m * y0 -> d );
2897
- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy0_1 ), vget_high_s16 (sy0_1 ))), x0_1m * y0 -> d );
2898
- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy1_0 ), vget_high_s16 (sy1_0 ))), x1_0m * y1 -> d );
2899
- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy1_1 ), vget_high_s16 (sy1_1 ))), x1_1m * y1 -> d );
2900
-
2901
2920
#if defined(__ARM_FEATURE_DOTPROD )
2902
2921
sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l )), x0_0d * y0 -> d );
2903
2922
sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), x0_1d * y0 -> d );
@@ -2926,7 +2945,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
2926
2945
#endif
2927
2946
}
2928
2947
2929
- * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 ( sumv1 ) ;
2948
+ sumf = vaddvq_f32 (vaddq_f32 ( sumv0 , sumv1 )) + summs ;
2930
2949
#elif defined(__AVX2__ )
2931
2950
// Initialize accumulator with zeros
2932
2951
__m256 acc = _mm256_setzero_ps ();
0 commit comments