@@ -2875,77 +2875,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    float summs = 0.0f;
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
 
-    for (int i = 0; i < nb; i += 2) {
+    for (int i = 0; i < nb; ++i) {
         const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
         const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
-        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
 
         const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
 
-        summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
-        summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1;
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
-        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
+        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
+        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
 
         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, vdupq_n_u8(0xf)));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
         // interleave
         const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
         const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
-        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
 
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
-
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
 #endif
     }
 
-    sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs;
+    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
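The change drops the manual two-block unroll (the x1_*/y1 lanes) and processes one q8_0 block, i.e. two q4_3 sub-blocks, per iteration, splitting the minimum-offset accumulator into summs0/summs1. For reference, here is a hedged scalar sketch of the reduction the NEON path computes, assuming the block layouts at this revision (QK4_3 == 16, two block_q4_3 per 32-wide block_q8_0, with y->s0/y->s1 holding y->d times the sum of each 16-quant half); the helper name vec_dot_q4_3_q8_0_ref is hypothetical, not from the commit:

// Hypothetical scalar reference for the NEON kernel above. Assumed layouts:
//   block_q4_3: ggml_fp16_t d, m; uint8_t qs[QK4_3/2];   // QK4_3 == 16
//   block_q8_0: float d, s0, s1; int8_t  qs[QK8_0];      // QK8_0 == 32,
//               s0/s1 == d * sum of the first/second 16 quants
static float vec_dot_q4_3_q8_0_ref(const int n, const block_q4_3 * x, const block_q8_0 * y) {
    const int nb = n / QK8_0; // one y block spans two x sub-blocks
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        const block_q8_0 * y0 = &y[i];
        for (int j = 0; j < 2; ++j) { // sub-block j pairs with y quants [16*j, 16*j + 16)
            const block_q4_3 * x0 = &x[2*i + j];
            const float d = GGML_FP16_TO_FP32(x0->d);
            const float m = GGML_FP16_TO_FP32(x0->m);
            int sumi = 0;
            for (int k = 0; k < QK4_3/2; ++k) {
                // low nibble is quant 2k, high nibble is quant 2k+1,
                // matching the vzip1q/vzip2q interleave in the NEON path
                sumi += (x0->qs[k] & 0xf) * y0->qs[16*j + 2*k + 0];
                sumi += (x0->qs[k] >>  4) * y0->qs[16*j + 2*k + 1];
            }
            // q4_3 dequantizes as v = d*q + m, so
            //   sum(v*yv) = d*y_d*sum(q*yq) + m*(y_d*sum(yq)),
            // and the second term is exactly what summs0/summs1 accumulate above
            sumf += d * y0->d * sumi + m * (j == 0 ? y0->s0 : y0->s1);
        }
    }
    return sumf;
}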