@@ -1310,6 +1310,29 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
     }
 }

+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented-out solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// I'm keeping the original.
+// TODO: Please try it and, if it does make a difference, uncomment it and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64   = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64  = _mm_add_epi32(hi64, sum128);
+    __m128i hi32   = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+#endif
+
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
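
If anyone wants to pick up the TODO above: a throwaway harness like the sketch below (my own, not part of the patch; compile with -mavx2) is enough to check a candidate horizontal sum against a plain scalar loop before swapping it in.

    // Hedged test harness for horizontal_sum(); not part of the patch.
    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    // Copy of the helper added in the hunk above.
    static inline float horizontal_sum(__m256i a) {
        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
        __m128i hi64   = _mm_unpackhi_epi64(sum128, sum128);
        __m128i sum64  = _mm_add_epi32(hi64, sum128);
        __m128i hi32   = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
        return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
    }

    int main(void) {
        int32_t v[8] = { 1, -2, 3, -4, 5, -6, 7, 127 };  // arbitrary lanes
        __m256i a = _mm256_loadu_si256((const __m256i *) v);

        int32_t ref = 0;                                 // scalar reference
        for (int j = 0; j < 8; ++j) ref += v[j];

        printf("simd = %.0f, scalar = %d\n", horizontal_sum(a), ref);
        return 0;
    }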
@@ -1399,14 +1422,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

 #if defined(__AVX2__)

-        // Compute the sum of the quants
-        // There is not better way of doing this???
-        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
-        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
-        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
-        __m128i sum64 = _mm_add_epi32(hi64, sum128);
-        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));

         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
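
For reference, the scalar equivalent of what the AVX2 branch now stores is tiny: y[i].s caches d times the sum of the block's quants, so the dot products later can fold per-block offsets in with a single scalar multiply. A hedged sketch (QK8_0 is 32 here; the helper name is mine):

    #include <stdint.h>

    // What y[i].s ends up holding after the AVX2 code above runs.
    static float q8_0_block_sum(const int8_t qs[32], float d) {
        int sum = 0;
        for (int j = 0; j < 32; ++j) {
            sum += qs[j];            // the block's already-rounded int8 quants
        }
        return d * (float) sum;      // same value horizontal_sum() produces, times d
    }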
@@ -2411,7 +2428,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         sum8 += x0->d * y0->s + x1->d * y1->s;

         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        //const int8x16_t  s8b = vdupq_n_s8(0x8);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2422,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

-        // sub 8
-        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2442,27 +2452,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));

-        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
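
Dropping the `sub 8` step is valid because q4_0 decodes element j as x_j = d*(q_j - 8), so sum_j x_j*y_j = d*dy*sum_j q_j*p_j - 8*d*(dy*sum_j p_j), and dy*sum_j p_j is exactly the y->s cached at quantization time; that is what the scalar sum8 accumulator above collects (the final reduction, outside this hunk, presumably applies the factor of -8). A hedged scalar sketch of the identity, assuming the interleaved nibble layout of this era (elements 2j and 2j+1 share byte j, which is why the y vectors get vuzp'd):

    #include <stdint.h>

    // One q4_0 x q8_0 block, multiplying raw nibbles and fixing up with y->s.
    static float dot_q4_0_q8_0_block(const uint8_t qx[16], float d,
                                     const int8_t  qy[32], float dy, float ys) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            sumi += (qx[j] & 0x0F) * qy[2*j];     // element 2j: low nibble
            sumi += (qx[j] >>   4) * qy[2*j + 1]; // element 2j+1: high nibble
        }
        return d*dy*(float)sumi - 8.0f*d*ys;      // ys == y->s == dy * sum(p_j)
    }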
@@ -2644,19 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
-        //const int16x8_t s0i = vaddq_s16(
-        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        //const int16x8_t s1i = vaddq_s16(
-        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
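
The same bookkeeping explains this deletion for q4_1: with x_j = d*q_j + m and y_j = dy*p_j, sum_j x_j*y_j = d*dy*sum_j q_j*p_j + m*(dy*sum_j p_j) = d*dy*sum_j q_j*p_j + m*y->s, so one scalar FMA per block (the summs accumulation the removed comment cites) replaces the whole vector y-sum. A hedged scalar sketch under the same layout assumptions as above:

    #include <stdint.h>

    // One q4_1 x q8_0 block; the offset term rides on the precomputed y->s.
    static float dot_q4_1_q8_0_block(const uint8_t qx[16], float d, float m,
                                     const int8_t  qy[32], float dy, float ys) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            sumi += (qx[j] & 0x0F) * qy[2*j];     // element 2j: low nibble
            sumi += (qx[j] >>   4) * qy[2*j + 1]; // element 2j+1: high nibble
        }
        return d*dy*(float)sumi + m*ys;           // ys == y->s == dy * sum(p_j)
    }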
@@ -2702,11 +2689,9 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *

         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        //const __m256 m0v = _mm256_broadcast_ss( m0 );

         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2728,17 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *

         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
-        //// Compute sum of y values
-        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        //const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        //// Accumulate d1*m0*y
-        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }

     // Return horizontal sum of the acc vector
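
The AVX2 deletions in the last two hunks are the same algebra: the per-iteration ysum the old code built with _mm256_madd_epi16 equals the quant sum already folded into y[i].s, so d1m0 * ysum and m0 * y[i].s are the same number. A quick standalone check (values are arbitrary, names mirror the code above):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        int8_t p[32];
        for (int j = 0; j < 32; ++j) p[j] = (int8_t)(j - 16); // arbitrary quants
        const float d1 = 0.05f, m0 = -0.7f;

        int ysum = 0;                        // what y16_l/y16_h/madd computed
        for (int j = 0; j < 32; ++j) ysum += p[j];

        const float old_term = (d1 * m0) * (float) ysum;  // deleted path: d1m0 * ysum
        const float new_term = m0 * (d1 * (float) ysum);  // kept path: m0 * y[i].s

        printf("old = %f, new = %f\n", old_term, new_term);
        return 0;
    }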