@@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s;          // d * sum(qs[i])
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");


 // reference implementation for deterministic creation of model files
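The new `s` field caches `d * sum(qs[l])` — the dequantized sum of the block — at quantization time. The dot-product kernels below exploit it twice: the Q4_1 offset term `m * sum(d_y*v)` collapses to `m * y->s`, and the Q4_0 "subtract 8" bias folds out via the identity `sum((q-8)*v) = sum(q*v) - 8*sum(v)`. A minimal standalone check of that identity (array contents are made up for illustration):

```c
// Check: subtracting the Q4_0 bias per element equals subtracting
// 8 * sum(v) once, so the bias can leave the inner loop entirely.
#include <assert.h>

int main(void) {
    const int q[4] = { 0, 5, 9, 15 };  // unsigned 4-bit quants, as stored
    const int v[4] = { 3, -2, 7, 1 };  // signed 8-bit quants from a q8_0 block

    int biased = 0, raw = 0, vsum = 0;
    for (int l = 0; l < 4; ++l) {
        biased += (q[l] - 8) * v[l];   // old kernel: bias inside the loop
        raw    +=  q[l]      * v[l];   // new kernel: no bias
        vsum   +=  v[l];
    }
    assert(biased == raw - 8 * vsum); // the identity this commit relies on
    return 0;
}
```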
@@ -1299,10 +1300,13 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r

         y[i].d = d;

+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
+            sum += y[i].qs[l];
         }
+        y[i].s = d*sum;
     }
 }

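The reference quantizer now establishes the invariant `y[i].s == y[i].d * sum(y[i].qs[l])`, computed from the *rounded* quants rather than the float inputs, so the integer identities used in the kernels hold exactly. A hedged standalone sketch of one block (the struct is re-declared locally, and the scale `d = amax/127` is assumed from the usual q8_0 reference code, which is not part of this hunk):

```c
// Standalone sketch: quantize one 32-float block the way the reference
// implementation above does, then verify the s-invariant independently.
#include <assert.h>
#include <math.h>
#include <stdint.h>

#define QK8_0 32

typedef struct {
    float  d;           // delta
    float  s;           // d * sum(qs[i])
    int8_t qs[QK8_0];   // quants
} block_q8_0_sketch;

int main(void) {
    float x[QK8_0];
    for (int l = 0; l < QK8_0; ++l) x[l] = 0.1f*l - 1.5f; // made-up inputs

    float amax = 0.0f;
    for (int l = 0; l < QK8_0; ++l) amax = fmaxf(amax, fabsf(x[l]));

    block_q8_0_sketch y;
    y.d = amax / ((1 << 7) - 1);            // assumed q8_0 scale formula
    const float id = y.d ? 1.0f/y.d : 0.0f;

    int sum = 0;
    for (int l = 0; l < QK8_0; ++l) {
        y.qs[l] = (int8_t) roundf(x[l]*id);
        sum += y.qs[l];
    }
    y.s = y.d * sum;

    int check = 0;                           // recompute the sum from storage
    for (int l = 0; l < QK8_0; ++l) check += y.qs[l];
    assert(y.s == y.d * check);              // invariant used by the kernels
    return 0;
}
```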
@@ -1332,6 +1336,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

         y[i].d = d;

+        int32x4_t accv = vdupq_n_s32(0);
+
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
@@ -1340,7 +1346,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
         }
+        int32_t sum = vaddvq_s32(accv);
+        y[i].s = d * sum;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1398,16 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );

 #if defined(__AVX2__)
+
+        // Compute the sum of the quants
+        // Is there no better way of doing this???
+        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
+        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
+        __m128i hi64  = _mm_unpackhi_epi64(sum128, sum128);
+        __m128i sum64 = _mm_add_epi32(hi64, sum128);
+        __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
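For reference, the intrinsics above are a standard tree reduction: the first add folds `i0..i3` vertically into 8 partial sums in `acc`, then the shuffles fold 8 lanes to 4, 2, and finally 1. A scalar model of the same computation (names are mine, not the patch's):

```c
#include <stdint.h>

// Scalar model of the AVX2 horizontal sum above, applied to the 8 int32
// lanes of acc: fold 256-bit (8 lanes) -> 128-bit (4) -> 64-bit (2) -> 1.
static int32_t hsum_8xi32_model(const int32_t v[8]) {
    int32_t half[4];
    for (int j = 0; j < 4; ++j) half[j] = v[j] + v[j + 4]; // 256 -> 128
    const int32_t pair0 = half[0] + half[2];               // 128 -> 64
    const int32_t pair1 = half[1] + half[3];
    return pair0 + pair1;                                  // 64 -> 32
}
```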
@@ -1430,6 +1450,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
+#if defined __AVX__
+    // TODO: vectorize this
+    for (int i = 0; i < nb; ++i) {
+        int sum = 0;
+        for (int l = 0; l < QK8_0; ++l) sum += y[i].qs[l];
+        y[i].s = y[i].d * sum;
+    }
+#endif
 }

 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
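One hedged observation on the fixup above: compilers that enable AVX2 also define `__AVX__`, so on AVX2 builds this scalar pass runs too and recomputes the `y[i].s` already stored by the vector loop — the value is identical, so it is redundant but harmless. If the redundancy were unwanted, a narrower guard (my sketch, not part of the patch) would be:

```c
// Illustrative alternative: restrict the scalar fixup to plain-AVX builds,
// since the AVX2 path already computed y[i].s with vector code.
#if defined(__AVX__) && !defined(__AVX2__)
    for (int i = 0; i < nb; ++i) {
        int sum = 0;
        for (int l = 0; l < QK8_0; ++l) sum += y[i].qs[l];
        y[i].s = y[i].d * sum;
    }
#endif
```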
@@ -2372,14 +2400,18 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        sum8 += x0->d * y0->s + x1->d * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);
+        //const int8x16_t  s8b = vdupq_n_s8(0x8);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2391,10 +2423,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

         // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
@@ -2410,21 +2442,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));

-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));

         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
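Per block, the kernel now computes `d_x*d_y*sum(q*v)` on the raw nibbles and recovers the bias once at the end: `d_x*d_y*sum((q-8)*v) = d_x*d_y*sum(q*v) - 8*d_x*(d_y*sum(v)) = d_x*d_y*sum(q*v) - 8*d_x*y->s`. Summed over blocks, the correction is exactly the `- 8 * sum8` term. A hedged scalar model of the rewritten kernel (block layout flattened into parallel arrays for illustration; real q4_0 packs two nibbles per byte):

```c
#include <stdint.h>

// Scalar model of the rewritten Q4_0 x Q8_0 dot product. Assumes the
// nibbles are already unpacked to one value per byte (0..15); ys[i] must
// equal dy[i] * sum(qy[i]), the invariant established at quantization.
static float dot_q4_0_q8_0_model(int nb,
                                 const float *dx, const uint8_t (*qx)[32],
                                 const float *dy, const float *ys,
                                 const int8_t (*qy)[32]) {
    float sumf = 0.0f, sum8 = 0.0f;
    for (int i = 0; i < nb; ++i) {
        int sumi = 0;
        for (int l = 0; l < 32; ++l) {
            sumi += qx[i][l] * qy[i][l];   // no "- 8" in the hot loop
        }
        sumf += dx[i]*dy[i]*sumi;
        sum8 += dx[i]*ys[i];               // bias term, one multiply per block
    }
    return sumf - 8.0f*sum8;               // fold the bias back in once
}
```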
@@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        summs += x0->m * y0->s + x1->m * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2598,16 +2644,18 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

-        const int16x8_t s0i = vaddq_s16(
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+        // We no longer need this: the sum of the y quants was computed during quantization,
+        // so the scalar line above (summs += x0->m*y0->s + x1->m*y1->s) gives the same result.
+        //const int16x8_t s0i = vaddq_s16(
+        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));

-        const int16x8_t s1i = vaddq_s16(
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+        //const int16x8_t s1i = vaddq_s16(
+        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));

-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
+        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
@@ -2637,24 +2685,28 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+        //const float * m0 = &x[i].m;
+
+        summs += x[i].m * y[i].s;

         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
+        //const __m256 m0v = _mm256_broadcast_ss( m0 );

         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
+        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2677,14 +2729,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );

-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256  ysum  = _mm256_cvtepi32_ps( ysumi );
+        // We no longer need this: the sum of the y quants was computed during quantization,
+        // so the single scalar line above (summs += x[i].m * y[i].s) gives the same result.
+        //// Compute sum of y values
+        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        //const __m256  ysum  = _mm256_cvtepi32_ps( ysumi );

-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
+        //// Accumulate d1*m0*y
+        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }

     // Return horizontal sum of the acc vector
@@ -2693,7 +2747,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );

-    sumf = _mm_cvtss_f32( res );
+    sumf = _mm_cvtss_f32( res ) + summs;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
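The Q4_1 counterpart of the identity: with `x = d*q + m`, each block contributes `sum((d*q + m)*(d_y*v)) = d*d_y*sum(q*v) + m*(d_y*sum(v)) = d*d_y*sum(q*v) + m*y->s`, which is why a single `+ summs` at the end replaces the vectorized sum-of-y computation. A hedged scalar model, using the same illustrative flattened layout as the Q4_0 sketch above:

```c
#include <stdint.h>

// Scalar model of the rewritten Q4_1 x Q8_0 dot product. ys[i] must equal
// dy[i] * sum(qy[i]); mx[i] is the per-block minimum m of the q4_1 block.
static float dot_q4_1_q8_0_model(int nb,
                                 const float *dx, const float *mx,
                                 const uint8_t (*qx)[32],
                                 const float *dy, const float *ys,
                                 const int8_t (*qy)[32]) {
    float sumf = 0.0f, summs = 0.0f;
    for (int i = 0; i < nb; ++i) {
        int sumi = 0;
        for (int l = 0; l < 32; ++l) {
            sumi += qx[i][l] * qy[i][l];
        }
        sumf  += dx[i]*dy[i]*sumi;
        summs += mx[i]*ys[i];   // offset term: one multiply-add per block
    }
    return sumf + summs;        // replaces the old per-block vector sum of y
}
```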