@@ -1758,7 +1758,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int
1758
1758
int8_t * restrict qs = vy ;
1759
1759
float * restrict ds = (float * ) ((uint8_t * ) vy + nb * QK8_0C );
1760
1760
1761
- #if __AVX512F__
1761
+ #if defined(__ARM_NEON )
1762
+ for (int i = 0 ; i < nb ; i ++ ) {
1763
+ float32x4_t srcv [8 ];
1764
+ float32x4_t asrcv [8 ];
1765
+ float32x4_t amaxv [8 ];
1766
+
1767
+ for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
1768
+ for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vabsq_f32 (srcv [l ]);
1769
+
1770
+ for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vmaxq_f32 (asrcv [2 * l ], asrcv [2 * l + 1 ]);
1771
+ for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
1772
+ for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
1773
+
1774
+ const float amax = vmaxvq_f32 (amaxv [0 ]);
1775
+
1776
+ const float d = amax / ((1 << 7 ) - 1 );
1777
+ const float id = d ? 1.0f /d : 0.0f ;
1778
+
1779
+ ds [i ] = d ;
1780
+
1781
+ for (int l = 0 ; l < 8 ; l ++ ) {
1782
+ const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
1783
+ const int32x4_t vi = vcvtnq_s32_f32 (v );
1784
+
1785
+ qs [i * QK8_0C + 4 * l + 0 ] = vgetq_lane_s32 (vi , 0 );
1786
+ qs [i * QK8_0C + 4 * l + 1 ] = vgetq_lane_s32 (vi , 1 );
1787
+ qs [i * QK8_0C + 4 * l + 2 ] = vgetq_lane_s32 (vi , 2 );
1788
+ qs [i * QK8_0C + 4 * l + 3 ] = vgetq_lane_s32 (vi , 3 );
1789
+ }
1790
+ }
1791
+ #elif defined(__AVX512F__ )
1762
1792
for (int i = 0 ; i < nb ; i ++ ) {
1763
1793
const __m512 x0 = _mm512_loadu_ps ( x + i * QK8_0C );
1764
1794
const __m512 x1 = _mm512_loadu_ps ( x + i * QK8_0C + QK8_0C /2 );
@@ -3095,7 +3125,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
3095
3125
3096
3126
float sumf = 0.0 ;
3097
3127
3098
- #if __AVX512F__
3128
+ #if defined(__ARM_NEON )
3129
+ float32x4_t sumv0 = vdupq_n_f32 (0.0f );
3130
+ float32x4_t sumv1 = vdupq_n_f32 (0.0f );
3131
+
3132
+ for (int i = 0 ; i < nb /2 ; i ++ ) {
3133
+ const int dst0 = i + i /2 * 2 ; // 0, 1, 4, 5, 8, 9, ...
3134
+ const int dst1 = i + i /2 * 2 + 2 ; // 2, 3, 6, 7, 10, 11 ...
3135
+
3136
+ const uint8x16_t m4b = vdupq_n_u8 (0xf );
3137
+ const int8x16_t s8b = vdupq_n_s8 (0x8 );
3138
+
3139
+ const uint8x16_t v0_01l = vld1q_u8 (& xqs [i * QK4_0 ]);
3140
+ const uint8x16_t v0_01h = vld1q_u8 (& xqs [i * QK4_0 + QK4_0 /2 ]);
3141
+
3142
+ // 4-bit -> 8-bit
3143
+ const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_01l , m4b ));
3144
+ const int8x16_t v0_0h = vreinterpretq_s8_u8 (vandq_u8 (v0_01h , m4b ));
3145
+ const int8x16_t v0_1l = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_01l , 4 ));
3146
+ const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_01h , 4 ));
3147
+
3148
+ // sub 8
3149
+ const int8x16_t v0_0ls = vsubq_s8 (v0_0l , s8b );
3150
+ const int8x16_t v0_0hs = vsubq_s8 (v0_0h , s8b );
3151
+ const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
3152
+ const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
3153
+
3154
+ // load y
3155
+ const int8x16_t v1_0l = vld1q_s8 (& yqs [dst0 * QK8_0C ]);
3156
+ const int8x16_t v1_0h = vld1q_s8 (& yqs [dst0 * QK8_0C + 16 ]);
3157
+ const int8x16_t v1_1l = vld1q_s8 (& yqs [dst1 * QK8_0C ]);
3158
+ const int8x16_t v1_1h = vld1q_s8 (& yqs [dst1 * QK8_0C + 16 ]);
3159
+
3160
+ #if defined(__ARM_FEATURE_DOTPROD )
3161
+ // dot product into int32x4_t
3162
+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0l ), v0_0hs , v1_0h );
3163
+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1l ), v0_1hs , v1_1h );
3164
+
3165
+ sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), xds [dst0 ]* yds [dst0 ]);
3166
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), xds [dst1 ]* yds [dst1 ]);
3167
+ #else
3168
+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0l ));
3169
+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0l ));
3170
+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0h ));
3171
+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0h ));
3172
+
3173
+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1l ));
3174
+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1l ));
3175
+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1h ));
3176
+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1h ));
3177
+
3178
+ const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
3179
+ const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
3180
+ const int32x4_t pl1 = vaddq_s32 (vpaddlq_s16 (pl1l ), vpaddlq_s16 (pl1h ));
3181
+ const int32x4_t ph1 = vaddq_s32 (vpaddlq_s16 (ph1l ), vpaddlq_s16 (ph1h ));
3182
+
3183
+ sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddq_s32 (pl0 , ph0 )), xds [dst0 ]* yds [dst0 ]);
3184
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddq_s32 (pl1 , ph1 )), xds [dst1 ]* yds [dst1 ]);
3185
+ #endif
3186
+ }
3187
+
3188
+ sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
3189
+
3190
+ #elif defined(__AVX512F__ )
3099
3191
// Initialize accumulator with zeros
3100
3192
__m512 acc = _mm512_setzero_ps ();
3101
3193
for (int i = 0 ; i < nb ; i += 4 ) {
0 commit comments