@@ -1825,9 +1825,9 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
1825
1825
}
1826
1826
}
1827
1827
1828
- static void ggml_vec_dot_q4_0_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy );
1828
+ static void ggml_vec_dot_q4_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy );
1829
1829
static void ggml_vec_dot_q4_1_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy );
1830
- static void ggml_vec_dot_q4_2_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy );
1830
+ static void ggml_vec_dot_q4_2_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy );
1831
1831
static void ggml_vec_dot_q4_3_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy );
1832
1832
static void ggml_vec_dot_q8_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy );
1833
1833
@@ -1837,7 +1837,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1837
1837
.quantize_row_q = quantize_row_q4_0 ,
1838
1838
.quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q4_0_reference ,
1839
1839
.quantize_row_q_dot = quantize_row_q8_1 ,
1840
- .vec_dot_q = ggml_vec_dot_q4_0_q8_1 ,
1840
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0 ,
1841
1841
},
1842
1842
[GGML_TYPE_Q4_1 ] = {
1843
1843
.dequantize_row_q = dequantize_row_q4_1 ,
@@ -1851,7 +1851,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1851
1851
.quantize_row_q = quantize_row_q4_2 ,
1852
1852
.quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q4_2_reference ,
1853
1853
.quantize_row_q_dot = quantize_row_q8_1 ,
1854
- .vec_dot_q = ggml_vec_dot_q4_2_q8_1 ,
1854
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0 ,
1855
1855
},
1856
1856
[GGML_TYPE_Q4_3 ] = {
1857
1857
.dequantize_row_q = dequantize_row_q4_3 ,
@@ -2475,7 +2475,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2475
2475
* s = sumf ;
2476
2476
}
2477
2477
2478
- static void ggml_vec_dot_q4_0_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2478
+ static void ggml_vec_dot_q4_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2479
2479
const int nb = n / QK8_1 ;
2480
2480
2481
2481
assert (n % QK8_1 == 0 );
@@ -2488,17 +2488,14 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
2488
2488
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
2489
2489
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
2490
2490
2491
- float sum8 = 0 ;
2492
-
2493
2491
for (int i = 0 ; i < nb ; i += 2 ) {
2494
2492
const block_q4_0 * restrict x0 = & x [i + 0 ];
2495
2493
const block_q4_0 * restrict x1 = & x [i + 1 ];
2496
2494
const block_q8_1 * restrict y0 = & y [i + 0 ];
2497
2495
const block_q8_1 * restrict y1 = & y [i + 1 ];
2498
2496
2499
- sum8 += x0 -> d * (y0 -> s0 + y0 -> s1 ) + x1 -> d * (y1 -> s0 + y1 -> s1 );
2500
-
2501
2497
const uint8x16_t m4b = vdupq_n_u8 (0xf );
2498
+ const int8x16_t s8b = vdupq_n_s8 (0x8 );
2502
2499
2503
2500
const uint8x16_t v0_0 = vld1q_u8 (x0 -> qs );
2504
2501
const uint8x16_t v0_1 = vld1q_u8 (x1 -> qs );
@@ -2509,6 +2506,12 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
2509
2506
const int8x16_t v0_1l = vreinterpretq_s8_u8 (vandq_u8 (v0_1 , m4b ));
2510
2507
const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_1 , 4 ));
2511
2508
2509
+ // sub 8
2510
+ const int8x16_t v0_0ls = vsubq_s8 (v0_0l , s8b );
2511
+ const int8x16_t v0_0hs = vsubq_s8 (v0_0h , s8b );
2512
+ const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
2513
+ const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
2514
+
2512
2515
// load y
2513
2516
const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
2514
2517
const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
@@ -2523,21 +2526,21 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
2523
2526
2524
2527
#if defined(__ARM_FEATURE_DOTPROD )
2525
2528
// dot product into int32x4_t
2526
- const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0l , v1_0ls ), v0_0h , v1_0hs );
2527
- const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1l , v1_1ls ), v0_1h , v1_1hs );
2529
+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0ls ), v0_0hs , v1_0hs );
2530
+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1ls ), v0_1hs , v1_1hs );
2528
2531
2529
2532
sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), x0 -> d * y0 -> d );
2530
2533
sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), x1 -> d * y1 -> d );
2531
2534
#else
2532
- const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0l ), vget_low_s8 (v1_0ls ));
2533
- const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0l ), vget_high_s8 (v1_0ls ));
2534
- const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0h ), vget_low_s8 (v1_0hs ));
2535
- const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0h ), vget_high_s8 (v1_0hs ));
2535
+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0ls ));
2536
+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0ls ));
2537
+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
2538
+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
2536
2539
2537
- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1l ), vget_low_s8 (v1_1ls ));
2538
- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1l ), vget_high_s8 (v1_1ls ));
2539
- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1h ), vget_low_s8 (v1_1hs ));
2540
- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1h ), vget_high_s8 (v1_1hs ));
2540
+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
2541
+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
2542
+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
2543
+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
2541
2544
2542
2545
const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
2543
2546
const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
@@ -2549,7 +2552,7 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
2549
2552
#endif
2550
2553
}
2551
2554
2552
- * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 ) - 8 * sum8 ;
2555
+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2553
2556
#elif defined(__AVX2__ )
2554
2557
// Initialize accumulator with zeros
2555
2558
__m256 acc = _mm256_setzero_ps ();
@@ -2775,7 +2778,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2775
2778
#endif
2776
2779
}
2777
2780
2778
- static void ggml_vec_dot_q4_2_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2781
+ static void ggml_vec_dot_q4_2_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2779
2782
const int nb = n / QK8_1 ;
2780
2783
2781
2784
assert (n % QK8_1 == 0 );
0 commit comments