Commit 35cc556

ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support
1 parent 82b2404 commit 35cc556
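
Before this commit, each kernel carried two complete NEON implementations: one using SDOT (`vdotq_s32`, guarded by `__ARM_FEATURE_DOTPROD`) and a near-identical fallback using widening multiply-accumulates (`vmlal_s8`). The commit merges them into a single loop body in which only the accumulator declarations and the multiply-accumulate steps are feature-guarded. A minimal sketch of that pattern, reduced to one 16-byte dot product (the helper name and scaffolding are illustrative, not taken from ggml):

```c
#include <arm_neon.h>
#include <stdint.h>

// Hypothetical helper showing the deduplication pattern: loads and control
// flow are shared; only the accumulator type and the accumulate step differ.
static int32_t dot_i8x16(const int8_t * a, const int8_t * b) {
    const int8x16_t va = vld1q_s8(a);
    const int8x16_t vb = vld1q_s8(b);
#if defined(__ARM_FEATURE_DOTPROD)
    // SDOT path: 4-way int8 dot products accumulated into int32 lanes.
    const int32x4_t acc = vdotq_s32(vdupq_n_s32(0), va, vb);
    return vaddvq_s32(acc);
#else
    // Fallback: widening int8 multiply-accumulate into int16 lanes
    // (safe here: at most two products of magnitude <= 127*127 per lane).
    int16x8_t acc = vmull_s8(vget_low_s8(va), vget_low_s8(vb));
    acc = vmlal_s8(acc, vget_high_s8(va), vget_high_s8(vb));
    return vaddlvq_s16(acc);
#endif
}
```

In the real kernels the int16 fallback accumulates many more products per lane, which is where the small ternary operands matter for avoiding overflow; the guarded final reductions (`vaddvq_s32` vs `vaddlvq_s16`) mirror the two accumulator types.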

1 file changed

ggml/src/ggml-quants.c

Lines changed: 47 additions & 157 deletions
@@ -5667,16 +5667,21 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
 
     const int nb = n / QK_K;
 
-#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+#if defined(__ARM_NEON)
     float sumf = 0.0f;
 
     uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
 
     const uint8x16_t shift = vld1q_u8(k_shift);
 
     for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
         int32x4_t sumi0 = vdupq_n_s32(0);
         int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
 
         // first 32 bytes of 5 elements
         {
@@ -5714,6 +5719,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
             const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
 
+#if defined(__ARM_FEATURE_DOTPROD)
             sumi0 = vdotq_s32(sumi0, sqx0, qy0);
             sumi1 = vdotq_s32(sumi1, sqx1, qy1);
             sumi0 = vdotq_s32(sumi0, sqx2, qy2);
@@ -5724,103 +5730,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vdotq_s32(sumi1, sqx7, qy7);
             sumi0 = vdotq_s32(sumi0, sqx8, qy8);
             sumi1 = vdotq_s32(sumi1, sqx9, qy9);
-        }
-
-        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
-        {
-            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
-            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
-            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
-            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
-            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
-            uint32_t qh;
-            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
-            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
-            qx5 = vmulq_u8(qx5, shift);
-
-            // multiply by 3 and keep the 2 bits above 8 bits
-            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
-            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
-            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
-            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
-            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
-            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
-
-            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
-            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
-            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
-            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
-            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
-            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
-
-            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
-            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
-            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
-            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
-            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
-            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
-        }
-
-        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
-        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
-
-        sumi0 = vaddq_s32(sumi0, sumi1);
-        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
-
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
-        sumf += d * (float) vaddvq_s32(sumi0);
-    }
-
-    *s = sumf;
-
-#elif defined __ARM_NEON
-    float sumf = 0.0f;
-
-    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
-
-    const uint8x16_t shift = vld1q_u8(k_shift);
-
-    for (int i = 0; i < nb; ++i) {
-        int16x8_t sumi0 = vdupq_n_s16(0);
-        int16x8_t sumi1 = vdupq_n_s16(0);
-
-        // first 32 bytes of 5 elements
-        {
-            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
-            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
-            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
-            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
-            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
-            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
-            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
-            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
-            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
-            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
-
-            // multiply by 3 and keep the 2 bits above 8 bits
-            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
-            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
-            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
-            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
-            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
-            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
-            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
-            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
-            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
-            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
-
-            const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
-            const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
-            const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
-            const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
-            const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
-            const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
-            const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
-            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
-            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
-            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
-
+#else
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -5841,6 +5751,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
+#endif
         }
 
         // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
@@ -5870,6 +5781,14 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
             const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
 
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+#else
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -5882,22 +5801,30 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+#endif
         }
 
         const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
         const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
 
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
         sumi0 = vaddq_s16(sumi0, sumi1);
         sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
 
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
         sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
     }
 
     *s = sumf;
 
-#elif defined __AVX2__
+#elif defined(__AVX2__)
     __m256 sumf = _mm256_setzero_ps();
 
     for (int i = 0; i < nb; ++i) {
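
The comment "multiply by 3 and keep the 2 bits above 8 bits", kept in the merged TQ1_0 path above, is the digit-extraction step for the base-3 packing: each byte is premultiplied by a power of 3 (the `vmulq_u8` by 3, 9, 27, 81, wrapping mod 256), after which the wanted trit sits in bits 8-9 of `3*x`. The `vhaddq_u8`/`vshrq_n_u8` sequence recovers those bits without widening, because the halving add keeps the ninth bit that a plain 8-bit add of `x` and `x >> 1` would lose. A standalone scalar check of that identity (my illustration, not code from the commit):

```c
#include <assert.h>
#include <stdint.h>

// Scalar model of vshrq_n_u8(vhaddq_u8(x, vshrq_n_u8(x, 1)), 6):
// vhaddq_u8 computes (a + b) >> 1 with 9-bit internal precision.
static uint8_t top2_of_3x(uint8_t x) {
    return (uint8_t) (((x + (x >> 1)) >> 1) >> 6);
}

int main(void) {
    for (int x = 0; x < 256; ++x) {
        // ((x + x/2) / 2) / 64 == (3*x) >> 8 for every byte value,
        // i.e. "multiply by 3 and keep the 2 bits above 8 bits".
        assert(top2_of_3x((uint8_t) x) == (3 * x) >> 8);
    }
    return 0;
}
```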
@@ -6063,14 +5990,19 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
 
     const int nb = n / QK_K;
 
-#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+#if defined(__ARM_NEON)
     float sumf = 0.0f;
 
     const uint8x16_t m3 = vdupq_n_u8(3);
 
     for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
         int32x4_t sumi0 = vdupq_n_s32(0);
         int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
 
         for (size_t j = 0; j < sizeof(x->qs); j += 32) {
             uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
@@ -6100,6 +6032,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
             const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
             const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
 
+#if defined(__ARM_FEATURE_DOTPROD)
             sumi0 = vdotq_s32(sumi0, sqx0, qy0);
             sumi1 = vdotq_s32(sumi1, sqx1, qy1);
             sumi0 = vdotq_s32(sumi0, sqx2, qy2);
@@ -6108,58 +6041,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vdotq_s32(sumi1, sqx5, qy5);
             sumi0 = vdotq_s32(sumi0, sqx6, qy6);
             sumi1 = vdotq_s32(sumi1, sqx7, qy7);
-        }
-
-        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
-        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
-
-        sumi0 = vaddq_s32(sumi0, sumi1);
-        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
-
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
-        sumf += d * (float) vaddvq_s32(sumi0);
-    }
-
-    *s = sumf;
-
-#elif defined __ARM_NEON
-    float sumf = 0.0f;
-
-    const uint8x16_t m3 = vdupq_n_u8(3);
-
-    for (int i = 0; i < nb; ++i) {
-        int16x8_t sumi0 = vdupq_n_s16(0);
-        int16x8_t sumi1 = vdupq_n_s16(0);
-
-        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
-            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
-            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
-            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
-            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
-            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
-            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
-            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
-            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
-
-            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
-            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
-            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
-            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
-            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
-            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
-            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
-            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
-
-            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
-            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
-            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
-            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
-            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
-            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
-            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
-            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
-
+#else
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -6176,22 +6058,30 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+#endif
         }
 
         const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
         const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
 
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
         sumi0 = vaddq_s16(sumi0, sumi1);
         sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
 
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
         sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
     }
 
     *s = sumf;
 
-#elif defined __AVX2__
+#elif defined(__AVX2__)
     __m256 sumf = _mm256_setzero_ps();
 
     for (int i = 0; i < nb; ++i) {
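
Both kernels finish a block the same way: the unsigned digit products `q*y` are accumulated, then the precomputed per-block sums of `y` (`y[i].bsums`) are subtracted. The likely reason (my reading of the encoding, not stated in this diff) is that the ternary weights are stored biased, with `q` in {0, 1, 2} meaning `w = q - 1`, so `sum(w*y) = sum(q*y) - sum(y)`. A scalar reference for one block under that assumption:

```c
#include <stdint.h>

// Hypothetical scalar reference: q holds biased ternary digits (0..2,
// i.e. w = q - 1, an assumption), y holds int8 activations, and ysum
// stands in for the precomputed total of y[i].bsums.
static int32_t tq_dot_block_ref(const uint8_t * q, const int8_t * y,
                                int n, int32_t ysum) {
    int32_t acc = 0;
    for (int k = 0; k < n; ++k) {
        acc += (int32_t) q[k] * y[k];  // what the vdotq_s32/vmlal_s8 loops build up
    }
    // sum((q - 1) * y) = sum(q * y) - sum(y): one subtraction undoes the bias,
    // matching vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))) above.
    return acc - ysum;
}
```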
