@@ -3079,32 +3079,50 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
3079
3079
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
3080
3080
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
3081
3081
3082
- for (int i = 0 ; i < nb ; ++ i ) {
3083
- const block_q8_0 * restrict x0 = & x [i ];
3084
- const block_q8_0 * restrict y0 = & y [i ];
3082
+ for (int i = 0 ; i < nb ; i += 2 ) {
3083
+ const block_q8_0 * restrict x0 = & x [i + 0 ];
3084
+ const block_q8_0 * restrict x1 = & x [i + 1 ];
3085
+ const block_q8_0 * restrict y0 = & y [i + 0 ];
3086
+ const block_q8_0 * restrict y1 = & y [i + 1 ];
3085
3087
3086
- const int8x16_t v0_0 = vld1q_s8 (x0 -> qs );
3087
- const int8x16_t v0_1 = vld1q_s8 (x0 -> qs + 16 );
3088
+ const int8x16_t x0_0 = vld1q_s8 (x0 -> qs );
3089
+ const int8x16_t x0_1 = vld1q_s8 (x0 -> qs + 16 );
3090
+ const int8x16_t x1_0 = vld1q_s8 (x1 -> qs );
3091
+ const int8x16_t x1_1 = vld1q_s8 (x1 -> qs + 16 );
3088
3092
3089
3093
// load y
3090
- const int8x16_t v1_0 = vld1q_s8 (y0 -> qs );
3091
- const int8x16_t v1_1 = vld1q_s8 (y0 -> qs + 16 );
3094
+ const int8x16_t y0_0 = vld1q_s8 (y0 -> qs );
3095
+ const int8x16_t y0_1 = vld1q_s8 (y0 -> qs + 16 );
3096
+ const int8x16_t y1_0 = vld1q_s8 (y1 -> qs );
3097
+ const int8x16_t y1_1 = vld1q_s8 (y1 -> qs + 16 );
3092
3098
3093
3099
#if defined(__ARM_FEATURE_DOTPROD )
3094
3100
sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddq_s32 (
3095
- vdotq_s32 (vdupq_n_s32 (0 ), v0_0 , v1_0 ),
3096
- vdotq_s32 (vdupq_n_s32 (0 ), v0_1 , v1_1 ))), x0 -> d * y0 -> d );
3097
- #else
3098
- const int16x8_t p0l = vmull_s8 (vget_low_s8 (v0_0 ), vget_low_s8 (v1_0 ));
3099
- const int16x8_t p0h = vmull_s8 (vget_high_s8 (v0_0 ), vget_high_s8 (v1_0 ));
3100
- const int16x8_t p1l = vmull_s8 (vget_low_s8 (v0_1 ), vget_low_s8 (v1_1 ));
3101
- const int16x8_t p1h = vmull_s8 (vget_high_s8 (v0_1 ), vget_high_s8 (v1_1 ));
3101
+ vdotq_s32 (vdupq_n_s32 (0 ), x0_0 , y0_0 ),
3102
+ vdotq_s32 (vdupq_n_s32 (0 ), x0_1 , y0_1 ))), x0 -> d * y0 -> d );
3102
3103
3103
- const int32x4_t pl = vaddq_s32 (vpaddlq_s16 (p0l ), vpaddlq_s16 (p0h ));
3104
- const int32x4_t ph = vaddq_s32 (vpaddlq_s16 (p1l ), vpaddlq_s16 (p1h ));
3104
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddq_s32 (
3105
+ vdotq_s32 (vdupq_n_s32 (0 ), x1_0 , y1_0 ),
3106
+ vdotq_s32 (vdupq_n_s32 (0 ), x1_1 , y1_1 ))), x1 -> d * y1 -> d );
3105
3107
3106
- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (pl ), x0 -> d * y0 -> d );
3107
- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph ), x0 -> d * y0 -> d );
3108
+ #else
3109
+ const int16x8_t p0_0 = vmull_s8 (vget_low_s8 (x0_0 ), vget_low_s8 (y0_0 ));
3110
+ const int16x8_t p0_1 = vmull_s8 (vget_high_s8 (x0_0 ), vget_high_s8 (y0_0 ));
3111
+ const int16x8_t p0_2 = vmull_s8 (vget_low_s8 (x0_1 ), vget_low_s8 (y0_1 ));
3112
+ const int16x8_t p0_3 = vmull_s8 (vget_high_s8 (x0_1 ), vget_high_s8 (y0_1 ));
3113
+
3114
+ const int16x8_t p1_0 = vmull_s8 (vget_low_s8 (x1_0 ), vget_low_s8 (y1_0 ));
3115
+ const int16x8_t p1_1 = vmull_s8 (vget_high_s8 (x1_0 ), vget_high_s8 (y1_0 ));
3116
+ const int16x8_t p1_2 = vmull_s8 (vget_low_s8 (x1_1 ), vget_low_s8 (y1_1 ));
3117
+ const int16x8_t p1_3 = vmull_s8 (vget_high_s8 (x1_1 ), vget_high_s8 (y1_1 ));
3118
+
3119
+ const int32x4_t p0 = vaddq_s32 (vpaddlq_s16 (p0_0 ), vpaddlq_s16 (p0_1 ));
3120
+ const int32x4_t p1 = vaddq_s32 (vpaddlq_s16 (p0_2 ), vpaddlq_s16 (p0_3 ));
3121
+ const int32x4_t p2 = vaddq_s32 (vpaddlq_s16 (p1_0 ), vpaddlq_s16 (p1_1 ));
3122
+ const int32x4_t p3 = vaddq_s32 (vpaddlq_s16 (p1_2 ), vpaddlq_s16 (p1_3 ));
3123
+
3124
+ sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddq_s32 (p0 , p1 )), x0 -> d * y0 -> d );
3125
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddq_s32 (p2 , p3 )), x1 -> d * y1 -> d );
3108
3126
#endif
3109
3127
}
3110
3128
0 commit comments