@@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 #define QK8_0 32
 typedef struct {
     float   d;         // delta
-    float   s;         // d * sum(qs[i])
+    float   s0;        // d * sum(qs[i]) low
+    float   s1;        // d * sum(qs[i]) high
     int8_t  qs[QK8_0]; // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");


 // reference implementation for deterministic creation of model files
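Note: a q8_0 block covers 32 quants, while a q4_3 block covers 16 elements with its own fp16 scale d and offset m, so one q8_0 block lines up with two q4_3 sub-blocks. Splitting the precomputed sum lets each sub-block's offset pair with its own half: s0 covers qs[0..15], s1 covers qs[16..31], and s0 + s1 recovers the old single field. A minimal sketch of that invariant (block_q8_0_sum is a hypothetical helper, not part of this change):

    // recover the former single-sum field from the split pair
    static inline float block_q8_0_sum(const block_q8_0 * y) {
        return y->s0 + y->s1; // == d * sum(qs[0..31]), up to float rounding
    }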
@@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {

         y[i].d = d;

-        int sum = 0;
-        for (int l = 0; l < QK8_0; ++l) {
-            const float v = x[i*QK8_0 + l]*id;
-            y[i].qs[l] = roundf(v);
-            sum += y[i].qs[l];
+        int sum0 = 0;
+        int sum1 = 0;
+
+        for (int l = 0; l < QK8_0/2; ++l) {
+            const float v0 = x[i*QK8_0           + l]*id;
+            const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id;
+
+            y[i].qs[          l] = roundf(v0);
+            y[i].qs[QK8_0/2 + l] = roundf(v1);
+
+            sum0 += y[i].qs[          l];
+            sum1 += y[i].qs[QK8_0/2 + l];
         }
-        y[i].s = d * sum;
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 }

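The per-element rounding is unchanged (still roundf of x*id), so the produced quants stay identical; only the sum bookkeeping is split by half. A quick property check against the old code (hypothetical test code, assuming <assert.h> and <math.h>):

    // the split sums must reconstruct the old single sum, up to float rounding
    int sum = 0;
    for (int l = 0; l < QK8_0; ++l) sum += y[i].qs[l];
    assert(fabsf((y[i].s0 + y[i].s1) - d*sum) <= 1e-5f*fabsf(d*sum));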
@@ -1335,9 +1345,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {

         y[i].d = d;

-        int32x4_t accv = vdupq_n_s32(0);
+        int32x4_t accv0 = vdupq_n_s32(0);
+        int32x4_t accv1 = vdupq_n_s32(0);

-        for (int l = 0; l < 8; l++) {
+        // low half
+        for (int l = 0; l < 4; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv0 = vaddq_s32(accv0, vi);
+        }
+
+        // high half
+        for (int l = 4; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);

@@ -1346,12 +1371,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);

-            accv = vaddq_s32(accv, vi);
+            accv1 = vaddq_s32(accv1, vi);
         }
-        int32_t sum = vaddvq_s32(accv);
-        y[i].s = d * sum;
+
+        const int32_t sum0 = vaddvq_s32(accv0);
+        const int32_t sum1 = vaddvq_s32(accv1);
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
+    // TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
     for (int i = 0; i < nb; i++) {
         // Load elements into 4 AVX vectors
         __m256 v0 = _mm256_loadu_ps( x );
@@ -1398,7 +1428,9 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {

 #if defined(__AVX2__)
         // Compute the sum of the quants and set y[i].s
-        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
+        y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));

         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
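Since v0..v3, and hence i0..i3, hold the 32 values in element order, i0 + i1 sums elements 0..15 and i2 + i3 sums 16..31, matching s0/s1 before the pack/permute below rearranges lanes. hsum_i32_8 is ggml's horizontal sum of the eight 32-bit lanes of a __m256i; a scalar model of what it computes (sketch only):

    // scalar model of hsum_i32_8
    static int32_t hsum_i32_8_model(const int32_t lane[8]) {
        int32_t s = 0;
        for (int k = 0; k < 8; ++k) s += lane[k];
        return s;
    }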
@@ -2395,7 +2427,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

-        sum8 += x0->d * y0->s + x1->d * y1->s;
+        sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);

         const uint8x16_t m4b = vdupq_n_u8(0xf);

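q4_0 stores x = d*(q - 8), so the offset part of a block dot product is -8*d*sum(y) = -8*d*(y->s0 + y->s1). sum8 accumulates d*(s0 + s1) per block so that a single -8*sum8 correction can be applied in the final reduction, outside this hunk. The identity written out (reasoning sketch, not patch code):

    // with x[l] = d*(q[l] - 8) and y[l] = yd*qy[l]:
    //   sum_l x[l]*y[l] = d*yd*sum_l q[l]*qy[l] - 8*d*(yd*sum_l qy[l])
    //                   = d*yd*sum_l q[l]*qy[l] - 8*d*(y->s0 + y->s1)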
@@ -2562,7 +2594,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

-        summs += x0->m * y0->s + x1->m * y1->s;
+        summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);

         const uint8x16_t m4b = vdupq_n_u8(0xf);

@@ -2575,22 +2607,22 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d * y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d * y1->d);
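The interleave moves from the y side (vuzp on the int8 quants) to the x side (vzip on the nibble vectors): zipping the low and high nibbles reproduces q8_0's natural element order, so the sequential vld1q_s8 loads of y can feed vdotq_s32 directly. A scalar model of the zip (a sketch, not the intrinsics):

    // vzip1q_s8(a,b) -> { a0,b0,a1,b1,...,a7,b7 }
    // vzip2q_s8(a,b) -> { a8,b8,...,a15,b15 }
    static void zip_s8x16_model(const int8_t a[16], const int8_t b[16],
                                int8_t lo[16], int8_t hi[16]) {
        for (int k = 0; k < 8; ++k) {
            lo[2*k] = a[k];     lo[2*k + 1] = b[k];
            hi[2*k] = a[8 + k]; hi[2*k + 1] = b[8 + k];
        }
    }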
@@ -2627,7 +2659,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;

-        summs += x[i].m * y[i].s;
+        summs += x[i].m * (y[i].s0 + y[i].s1);

         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
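Same hoist in the AVX path: with x = d*q + m, the per-block offset term is m*sum(y) = m*yd*sum(qy) = m*(y->s0 + y->s1), so summs collects it once per block and the vector loop handles only the d*q part. Worked form (sketch):

    // sum_l (d*q[l] + m) * y[l] = d * sum_l q[l]*y[l] + m * (y->s0 + y->s1)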
@@ -2845,88 +2877,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

-    for (int i = 0; i < nb; i += 2) {
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
         const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
         const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
-        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];

         const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
-        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);

-        const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
-        const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
-        const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
-        const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
+        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
+        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;

         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));

         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0xf)));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

         // interleave
         const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
         const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
-        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);

         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-        const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
-        const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));

-        const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
-        const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);

 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));

-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
-
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
 #endif
     }

-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
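With x = d*q + m per 16-element q4_3 sub-block, a q8_0 block contributes d0*yd*sum_lo(q*qy) + m0*sum_lo(y) + d1*yd*sum_hi(q*qy) + m1*sum_hi(y), where sum_lo(y) = y->s0 and sum_hi(y) = y->s1. Folding the m-terms into summs0/summs1 is what lets the whole vmovl/vaddl widening-reduction chain go, and the two-blocks-per-iteration unroll with it. The per-block identity (reasoning sketch, not patch code):

    // sub-block scales/offsets (d0, m0) and (d1, m1), with y[l] = yd*qy[l]:
    //   sum_lo (d0*q + m0)*y + sum_hi (d1*q + m1)*y
    //     = d0*yd*sum_lo q*qy + d1*yd*sum_hi q*qy + m0*y->s0 + m1*y->s1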
@@ -2971,9 +2968,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
         const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);

-        int sy_0 = 0;
-        int sy_1 = 0;
-
         int sxy_0 = 0;
         int sxy_1 = 0;

@@ -2993,15 +2987,11 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
             const int y0_1 = y0[2*(j + QK8_0/4) + 0];
             const int y1_1 = y0[2*(j + QK8_0/4) + 1];

-            sy_0 += y0_0 + y1_0;
-            sy_1 += y0_1 + y1_1;
-
             sxy_0 += x0_0*y0_0 + x1_0*y1_0;
             sxy_1 += x0_1*y0_1 + x1_1*y1_1;
         }

-        sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
-        sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+        sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
     }

     *s = sumf;
 #endif
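A self-contained demo of the bookkeeping, independent of ggml (all values hypothetical; m0/m1 stand in for q4_3 sub-block offsets):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        float x[32];
        for (int l = 0; l < 32; ++l) x[l] = sinf(0.3f*(float)l); // arbitrary data

        // q8_0-style quantization with split sums
        float amax = 0.0f;
        for (int l = 0; l < 32; ++l) amax = fmaxf(amax, fabsf(x[l]));
        const float d  = amax/127.0f;
        const float id = d ? 1.0f/d : 0.0f;

        int8_t qs[32];
        int sum0 = 0, sum1 = 0;
        for (int l = 0; l < 16; ++l) {
            qs[l]      = (int8_t)roundf(x[l]*id);
            qs[16 + l] = (int8_t)roundf(x[16 + l]*id);
            sum0 += qs[l];
            sum1 += qs[16 + l];
        }
        const float s0 = d*(float)sum0, s1 = d*(float)sum1;

        // offset term of a q4_3-style dot product: one m per 16-element half
        const float m0 = 0.25f, m1 = -0.5f;
        float direct = 0.0f;
        for (int l = 0; l < 16; ++l) direct += m0*(d*qs[l]) + m1*(d*qs[16 + l]);

        printf("direct = %f  via s0/s1 = %f\n", direct, m0*s0 + m1*s1); // agree up to rounding
        return 0;
    }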