@@ -5667,16 +5667,21 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *

const int nb = n / QK_K;

- #if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+ #if defined(__ARM_NEON)
float sumf = 0.0f;

uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};

const uint8x16_t shift = vld1q_u8(k_shift);

for (int i = 0; i < nb; ++i) {
+ #if defined(__ARM_FEATURE_DOTPROD)
int32x4_t sumi0 = vdupq_n_s32(0);
int32x4_t sumi1 = vdupq_n_s32(0);
+ #else
+ int16x8_t sumi0 = vdupq_n_s16(0);
+ int16x8_t sumi1 = vdupq_n_s16(0);
+ #endif

// first 32 bytes of 5 elements
{
@@ -5714,6 +5719,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);

+ #if defined(__ARM_FEATURE_DOTPROD)
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
@@ -5724,103 +5730,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
sumi1 = vdotq_s32(sumi1, sqx7, qy7);
sumi0 = vdotq_s32(sumi0, sqx8, qy8);
sumi1 = vdotq_s32(sumi1, sqx9, qy9);
- }
-
- // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
- {
- uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
- uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
- uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
- uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
- uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
- uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
- uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
- qx5 = vmulq_u8(qx5, shift);
-
- // multiply by 3 and keep the 2 bits above 8 bits
- int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
- int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
- int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
- int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
- int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
- int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
-
- const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
- const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
- const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
- const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
- const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
- const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
-
- sumi0 = vdotq_s32(sumi0, sqx0, qy0);
- sumi1 = vdotq_s32(sumi1, sqx1, qy1);
- sumi0 = vdotq_s32(sumi0, sqx2, qy2);
- sumi1 = vdotq_s32(sumi1, sqx3, qy3);
- sumi0 = vdotq_s32(sumi0, sqx4, qy4);
- sumi1 = vdotq_s32(sumi1, sqx5, qy5);
- }
-
- const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
- const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
-
- sumi0 = vaddq_s32(sumi0, sumi1);
- sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
-
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
- sumf += d * (float) vaddvq_s32(sumi0);
- }
-
- *s = sumf;
-
- #elif defined __ARM_NEON
- float sumf = 0.0f;
-
- uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
-
- const uint8x16_t shift = vld1q_u8(k_shift);
-
- for (int i = 0; i < nb; ++i) {
- int16x8_t sumi0 = vdupq_n_s16(0);
- int16x8_t sumi1 = vdupq_n_s16(0);
-
- // first 32 bytes of 5 elements
- {
- uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
- uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
- uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
- uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
- uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
- uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
- uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
- uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
- uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
- uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
-
- // multiply by 3 and keep the 2 bits above 8 bits
- int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
- int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
- int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
- int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
- int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
- int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
- int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
- int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
- int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
- int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
-
- const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
- const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
- const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
- const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
- const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
- const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
- const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
- const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
- const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
- const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
-
+ #else
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -5841,6 +5751,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
+ #endif
}

// last 16 bytes of 5-element, along with the 4 bytes of 4 elements
@@ -5870,6 +5781,14 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);

+ #if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+ #else
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -5882,22 +5801,30 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+ #endif
}

const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);

+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ #if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vaddq_s32(sumi0, sumi1);
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+ sumf += d * (float) vaddvq_s32(sumi0);
+ #else
sumi0 = vaddq_s16(sumi0, sumi1);
sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));

- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
sumf += d * (float) vaddlvq_s16(sumi0);
+ #endif
}

*s = sumf;

- #elif defined __AVX2__
+ #elif defined(__AVX2__)
__m256 sumf = _mm256_setzero_ps();

for (int i = 0; i < nb; ++i) {
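Note: the hunks above merge the two separate NEON implementations of ggml_vec_dot_tq1_0_q8_K (one requiring the ARMv8.2 dot-product extension, one not) into a single loop that switches only the accumulator type and the multiply-accumulate intrinsic at compile time. A minimal standalone sketch of that dispatch pattern follows; the helper name is hypothetical and not part of ggml:

#include <arm_neon.h>
#include <stdint.h>

// dot16_i8: dot product of 16 signed bytes, dispatched the same way as the
// kernel above: vdotq_s32 when available, widening vmull_s8 otherwise.
static inline int32_t dot16_i8(const int8_t * a, const int8_t * b) {
    const int8x16_t va = vld1q_s8(a);
    const int8x16_t vb = vld1q_s8(b);
#if defined(__ARM_FEATURE_DOTPROD)
    // vdotq_s32 sums groups of 4 int8 products into 4 int32 lanes.
    const int32x4_t acc = vdotq_s32(vdupq_n_s32(0), va, vb);
    return vaddvq_s32(acc);                       // reduce the 4 int32 lanes
#else
    // Fallback: widening multiplies into int16 lanes, then widening
    // horizontal reductions (what the vmlal_s8 path above does, spread
    // across ten register pairs).
    const int16x8_t lo = vmull_s8(vget_low_s8(va),  vget_low_s8(vb));
    const int16x8_t hi = vmull_s8(vget_high_s8(va), vget_high_s8(vb));
    return vaddlvq_s16(lo) + vaddlvq_s16(hi);     // reduce 16 int16 lanes
#endif
}

In the real kernel the fallback can keep 16-bit accumulators across a whole block because the decoded ternary values are at most 2, so each int16 lane stays well within range.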
@@ -6063,14 +5990,19 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *

const int nb = n / QK_K;

- #if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+ #if defined(__ARM_NEON)
float sumf = 0.0f;

const uint8x16_t m3 = vdupq_n_u8(3);

for (int i = 0; i < nb; ++i) {
+ #if defined(__ARM_FEATURE_DOTPROD)
int32x4_t sumi0 = vdupq_n_s32(0);
int32x4_t sumi1 = vdupq_n_s32(0);
+ #else
+ int16x8_t sumi0 = vdupq_n_s16(0);
+ int16x8_t sumi1 = vdupq_n_s16(0);
+ #endif

for (size_t j = 0; j < sizeof(x->qs); j += 32) {
uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
@@ -6100,6 +6032,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);

+ #if defined(__ARM_FEATURE_DOTPROD)
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
@@ -6108,58 +6041,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
sumi0 = vdotq_s32(sumi0, sqx6, qy6);
sumi1 = vdotq_s32(sumi1, sqx7, qy7);
- }
-
- const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
- const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
-
- sumi0 = vaddq_s32(sumi0, sumi1);
- sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
-
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
- sumf += d * (float) vaddvq_s32(sumi0);
- }
-
- *s = sumf;
-
- #elif defined __ARM_NEON
- float sumf = 0.0f;
-
- const uint8x16_t m3 = vdupq_n_u8(3);
-
- for (int i = 0; i < nb; ++i) {
- int16x8_t sumi0 = vdupq_n_s16(0);
- int16x8_t sumi1 = vdupq_n_s16(0);
-
- for (size_t j = 0; j < sizeof(x->qs); j += 32) {
- uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
- uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
- uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
- uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
- uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
- uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
- uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
- uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
-
- int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
- int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
- int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
- int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
- int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
- int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
- int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
- int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
-
- const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
- const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
- const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
- const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
- const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
- const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
- const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
- const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
-
+ #else
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -6176,22 +6058,30 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+ #endif
}

const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);

+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ #if defined(__ARM_FEATURE_DOTPROD)
+ sumi0 = vaddq_s32(sumi0, sumi1);
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+ sumf += d * (float) vaddvq_s32(sumi0);
+ #else
sumi0 = vaddq_s16(sumi0, sumi1);
sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));

- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
sumf += d * (float) vaddlvq_s16(sumi0);
+ #endif
}

*s = sumf;

- #elif defined __AVX2__
+ #elif defined(__AVX2__)
__m256 sumf = _mm256_setzero_ps();

for (int i = 0; i < nb; ++i) {
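Note: both kernels share the same per-block epilogue. The ternary weights are decoded as unsigned digits in {0, 1, 2}, so the integer accumulator holds sum(t * q); subtracting the precomputed activation sums from y[i].bsums once per block turns that into sum((t - 1) * q), i.e. a dot product with weights in {-1, 0, +1}. A scalar model of that correction, with illustrative names rather than the ggml API:

#include <stdint.h>

// One block: digits[k] in {0,1,2}, q[k] are the int8 activations, and d is
// the combined scale (x[i].d converted to float, times y[i].d).
static float tq_block_dot(const uint8_t * digits, const int8_t * q, int n, float d) {
    int32_t sumi = 0;   // what the SIMD accumulators hold: sum of t * q
    int32_t ysum = 0;   // what y[i].bsums stores, summed over the block
    for (int k = 0; k < n; ++k) {
        sumi += digits[k] * q[k];
        ysum += q[k];
    }
    // sum((t - 1) * q) == sum(t * q) - sum(q)
    return d * (float)(sumi - ysum);
}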