@@ -1829,7 +1829,7 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
 static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
@@ -1864,8 +1864,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q         = dequantize_row_q8_0,
         .quantize_row_q           = quantize_row_q8_0,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = ggml_vec_dot_q8_0_q8_1,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q8_0_q8_0,
     },
     [GGML_TYPE_Q8_1] = {
         .dequantize_row_q         = NULL,   // TODO
@@ -3062,23 +3062,23 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
 #endif
 }
 
-static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK8_1;
+static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
 
-    assert(n % QK8_1 == 0);
+    assert(n % QK8_0 == 0);
     assert(nb % 2 == 0);
-    assert(QK8_1 == QK8_0);
+    assert(QK8_0 == QK8_0);
 
     const block_q8_0 * restrict x = vx;
-    const block_q8_1 * restrict y = vy;
+    const block_q8_0 * restrict y = vy;
 
-#if defined(__ARM_NEON_XXX)
+#if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; ++i) {
         const block_q8_0 * restrict x0 = &x[i];
-        const block_q8_1 * restrict y0 = &y[i];
+        const block_q8_0 * restrict y0 = &y[i];
 
         const int8x16_t v0_0 = vld1q_s8(x0->qs);
         const int8x16_t v0_1 = vld1q_s8(x0->qs + 16);
@@ -3096,28 +3096,16 @@ static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void *
                         vdotq_s32(vdupq_n_s32(0), v0_0, v1_1),
                         vdotq_s32(vdupq_n_s32(0), v0_1, v1_0))), x0->d*y0->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+        const int16x8_t p0l = vmull_s8(vget_low_s8 (v0_0), vget_low_s8 (v1_0));
+        const int16x8_t p0h = vmull_s8(vget_high_s8(v0_0), vget_high_s8(v1_0));
+        const int16x8_t p1l = vmull_s8(vget_low_s8 (v0_1), vget_low_s8 (v1_1));
+        const int16x8_t p1h = vmull_s8(vget_high_s8(v0_1), vget_high_s8(v1_1));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+        const int32x4_t pl = vaddq_s32(vpaddlq_s16(p0l), vpaddlq_s16(p0h));
+        const int32x4_t ph = vaddq_s32(vpaddlq_s16(p1l), vpaddlq_s16(p1h));
 
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
-                vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
-                vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
-
-        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
-                vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
-                vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph), x0->d*y0->d);
 #endif
     }
 
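Note on the #else fallback added above, for readers without the NEON reference handy: it widens int8 products to int16 with vmull_s8, then pairwise-adds into int32 lanes with vpaddlq_s16 before the float scale is applied. The sketch below is a standalone illustration of that pattern for a single 32-element block; the helper name is hypothetical, it is not part of this patch, and the final vaddvq_s32 horizontal reduction assumes AArch64.

#include <arm_neon.h>
#include <stdint.h>

// Hypothetical helper: integer dot product of one 32-element int8 block,
// using the same widen-then-pairwise-add pattern as the #else branch above.
static int32_t q8_block_dot(const int8_t * a, const int8_t * b) {
    const int8x16_t a0 = vld1q_s8(a);
    const int8x16_t a1 = vld1q_s8(a + 16);
    const int8x16_t b0 = vld1q_s8(b);
    const int8x16_t b1 = vld1q_s8(b + 16);

    // int8 x int8 -> int16, 8 lanes per multiply (no overflow possible)
    const int16x8_t p0l = vmull_s8(vget_low_s8 (a0), vget_low_s8 (b0));
    const int16x8_t p0h = vmull_s8(vget_high_s8(a0), vget_high_s8(b0));
    const int16x8_t p1l = vmull_s8(vget_low_s8 (a1), vget_low_s8 (b1));
    const int16x8_t p1h = vmull_s8(vget_high_s8(a1), vget_high_s8(b1));

    // pairwise-add int16 lanes into int32 lanes, then accumulate
    const int32x4_t pl = vaddq_s32(vpaddlq_s16(p0l), vpaddlq_s16(p0h));
    const int32x4_t ph = vaddq_s32(vpaddlq_s16(p1l), vpaddlq_s16(p1h));

    // horizontal sum of all int32 lanes (AArch64)
    return vaddvq_s32(vaddq_s32(pl, ph));
}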
@@ -3132,7 +3120,7 @@ static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void *
 
         int sumi = 0;
 
-        for (int j = 0; j < QK8_1; j++) {
+        for (int j = 0; j < QK8_0; j++) {
             const int v0 = x0[j];
             const int v1 = y0[j];
 
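For reference, a minimal scalar sketch of what the renamed ggml_vec_dot_q8_0_q8_0 computes, mirroring the scalar loop in the last hunk: per block, accumulate the int8 products, then scale by the two block scales. The struct layout and QK8_0 value are assumptions for illustration (they are not shown in this diff), so the type carries a _ref suffix rather than claiming to be the real block_q8_0.

#include <stdint.h>

#define QK8_0 32  // assumed block size, consistent with n / QK8_0 above

// Assumed block layout: one scale d plus QK8_0 signed 8-bit quants.
typedef struct {
    float  d;
    int8_t qs[QK8_0];
} block_q8_0_ref;

// Reference dot product: sum over blocks of d_x * d_y * (integer dot of qs).
static void vec_dot_q8_0_q8_0_ref(int n, float * s,
                                  const block_q8_0_ref * x,
                                  const block_q8_0_ref * y) {
    const int nb = n / QK8_0;
    float sumf = 0.0f;
    for (int i = 0; i < nb; i++) {
        int sumi = 0;
        for (int j = 0; j < QK8_0; j++) {
            sumi += x[i].qs[j] * y[i].qs[j];
        }
        sumf += sumi * x[i].d * y[i].d;
    }
    *s = sumf;
}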