@@ -669,21 +669,21 @@ void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict
     for (int i = 0; i < nb; i++) {

         for (int j = 0; j < qk/4; ++j) {
-            int8_t x0 = (int8_t)x[i*qk + j*4 + 0];
-            int8_t x1 = (int8_t)x[i*qk + j*4 + 1];
-            int8_t x2 = (int8_t)x[i*qk + j*4 + 2];
-            int8_t x3 = (int8_t)x[i*qk + j*4 + 3];
+            int8_t x0 = (int8_t)x[i*qk + 0      + j];
+            int8_t x1 = (int8_t)x[i*qk + 1*qk/4 + j];
+            int8_t x2 = (int8_t)x[i*qk + 2*qk/4 + j];
+            int8_t x3 = (int8_t)x[i*qk + 3*qk/4 + j];

-            const uint8_t xi0 = x0 >= 0 ? x0 : 3;
-            const uint8_t xi1 = x1 >= 0 ? x1 : 3;
-            const uint8_t xi2 = x2 >= 0 ? x2 : 3;
-            const uint8_t xi3 = x3 >= 0 ? x3 : 3;
+            const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3;
+            const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3;
+            const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3;
+            const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 2 : 3;

             y[i].qs[j] = 0;
-            y[i].qs[j] |= (xi0 << 6);
-            y[i].qs[j] |= (xi1 << 4);
-            y[i].qs[j] |= (xi2 << 2);
-            y[i].qs[j] |= (xi3 << 0);
+            y[i].qs[j] |= (xi0 << 0);
+            y[i].qs[j] |= (xi1 << 2);
+            y[i].qs[j] |= (xi2 << 4);
+            y[i].qs[j] |= (xi3 << 6);
         }
     }
 }
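
The quantized layout is transposed here: byte j of a block no longer holds four consecutive values, but one value from each quarter of the block, at bit offsets 0, 2, 4, and 6. Each ternary value is stored biased by +2, so that {-1, 0, +1} becomes {1, 2, 3}. A minimal sketch of the packing rule, assuming the inputs are already ternary; pack_ternary4 is a hypothetical helper, not part of the patch:

    // packs elements j + 0*qk/4, j + 1*qk/4, j + 2*qk/4, j + 3*qk/4 of a
    // block into one byte; each value in {-1, 0, +1} is stored biased by +2
    static inline uint8_t pack_ternary4(int8_t v0, int8_t v1, int8_t v2, int8_t v3) {
        return (uint8_t)(((v0 + 2) << 0) | ((v1 + 2) << 2) |
                         ((v2 + 2) << 4) | ((v3 + 2) << 6));
    }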
@@ -1555,12 +1555,12 @@ void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int6
     for (int i = 0; i < nb; i++) {

         for (int j = 0; j < qk/4; ++j) {
-            const int8_t * q = (const int8_t *) (q22_grid + x[i].qs[j]);
+            const int8_t q = x[i].qs[j];

-            *y++ = (float) q[0];
-            *y++ = (float) q[1];
-            *y++ = (float) q[2];
-            *y++ = (float) q[3];
+            y[i*qk + j + 0     ] = (float) (((q >> 0) & 3) - 2);
+            y[i*qk + j + 1*qk/4] = (float) (((q >> 2) & 3) - 2);
+            y[i*qk + j + 2*qk/4] = (float) (((q >> 4) & 3) - 2);
+            y[i*qk + j + 3*qk/4] = (float) (((q >> 6) & 3) - 2);
         }
     }
 }
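
Decoding is now arithmetic, (q >> 2k) & 3 minus 2 for the k-th quarter, so the q22_grid lookup table is no longer needed in this path. As a quick round-trip sanity check, a sketch assuming a single block of qk == 32 values already in {-1, 0, +1}:

    float src[32], dst[32];
    block_q2_2 blk;
    for (int t = 0; t < 32; ++t) src[t] = (float)(t % 3 - 1); // -1, 0, +1, ...
    quantize_row_q2_2_reference(src, &blk, 32);
    dequantize_row_q2_2(&blk, dst, 32);
    // dst[t] == src[t] for all t, despite the transposed in-byte order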
@@ -3929,82 +3929,18 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r
 #if defined(__AVX2__)
     __m256 acc = _mm256_setzero_ps();

-    int leftovers = nb % 2;
-
-    for (int i = 0; i < nb - leftovers; i += 2) {
-
-        const __m256 d0 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 0].d) );
-        const __m256 d1 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 1].d) );
-
-        // assuming two consecutive blocks are contiguous AND aligned
-        __m128i xq16b = _mm_load_si128((const __m128i *) (x[i].qs));
-        __m256i xq16  = MM256_SET_M128I(xq16b, xq16b);
-        __m256i xq8l0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8( 5, -1,  5, -1,  5, -1,  5, -1,
-                                                                   4, -1,  4, -1,  4, -1,  4, -1,
-                                                                   1, -1,  1, -1,  1, -1,  1, -1,
-                                                                   0, -1,  0, -1,  0, -1,  0, -1));
-        __m256i xq8h0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8( 7, -1,  7, -1,  7, -1,  7, -1,
-                                                                   6, -1,  6, -1,  6, -1,  6, -1,
-                                                                   3, -1,  3, -1,  3, -1,  3, -1,
-                                                                   2, -1,  2, -1,  2, -1,  2, -1));
-        __m256i xq8l1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(13, -1, 13, -1, 13, -1, 13, -1,
-                                                                  12, -1, 12, -1, 12, -1, 12, -1,
-                                                                   9, -1,  9, -1,  9, -1,  9, -1,
-                                                                   8, -1,  8, -1,  8, -1,  8, -1));
-        __m256i xq8h1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(15, -1, 15, -1, 15, -1, 15, -1,
-                                                                  14, -1, 14, -1, 14, -1, 14, -1,
-                                                                  11, -1, 11, -1, 11, -1, 11, -1,
-                                                                  10, -1, 10, -1, 10, -1, 10, -1));
-        __m256i shift = _mm256_set_epi16(64, 16, 4, 1,
-                                         64, 16, 4, 1,
-                                         64, 16, 4, 1,
-                                         64, 16, 4, 1);
-        xq8l0 = _mm256_mullo_epi16(xq8l0, shift);
-        xq8h0 = _mm256_mullo_epi16(xq8h0, shift);
-        xq8l1 = _mm256_mullo_epi16(xq8l1, shift);
-        xq8h1 = _mm256_mullo_epi16(xq8h1, shift);
-        xq8l0 = _mm256_srai_epi16(xq8l0, 14);
-        xq8h0 = _mm256_srai_epi16(xq8h0, 14);
-        xq8l1 = _mm256_srai_epi16(xq8l1, 14);
-        xq8h1 = _mm256_srai_epi16(xq8h1, 14);
-        __m256i xq8_0 = _mm256_packs_epi16(xq8l0, xq8h0);
-        __m256i xq8_1 = _mm256_packs_epi16(xq8l1, xq8h1);
-
-        __m256i yq8_0 = _mm256_loadu_si256((const __m256i *) (y[i + 0].qs));
-        __m256i yq8_1 = _mm256_loadu_si256((const __m256i *) (y[i + 1].qs));
-
-        const __m256 q0 = mul_sum_i8_pairs_float(xq8_0, yq8_0);
-        const __m256 q1 = mul_sum_i8_pairs_float(xq8_1, yq8_1);
-
-        acc = _mm256_fmadd_ps( d0, q0, acc );
-        acc = _mm256_fmadd_ps( d1, q1, acc );
-    }
-
-    for (int i = nb - leftovers; i < nb; ++i) {
+    for (int i = 0; i < nb; ++i) {

         const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) );

-        __m128i xq8b = _mm_loadu_si64(x[i].qs);
-        __m256i xq8  = MM256_SET_M128I(xq8b, xq8b);
-        __m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
-                                                                4, -1, 4, -1, 4, -1, 4, -1,
-                                                                1, -1, 1, -1, 1, -1, 1, -1,
-                                                                0, -1, 0, -1, 0, -1, 0, -1));
-        __m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
-                                                                6, -1, 6, -1, 6, -1, 6, -1,
-                                                                3, -1, 3, -1, 3, -1, 3, -1,
-                                                                2, -1, 2, -1, 2, -1, 2, -1));
-        __m256i shift = _mm256_set_epi16(64, 16, 4, 1,
-                                         64, 16, 4, 1,
-                                         64, 16, 4, 1,
-                                         64, 16, 4, 1);
-        xq8l = _mm256_mullo_epi16(xq8l, shift);
-        xq8h = _mm256_mullo_epi16(xq8h, shift);
-        xq8l = _mm256_srai_epi16(xq8l, 14);
-        xq8h = _mm256_srai_epi16(xq8h, 14);
-        xq8  = _mm256_packs_epi16(xq8l, xq8h);
-
-        __m256i yq8 = _mm256_loadu_si256((const __m256i *) (y[i].qs));
+        // assuming this is always aligned
+        __m256i xq8 = _mm256_set1_epi64x(*(const int64_t *) x[i].qs);
+        xq8 = _mm256_srlv_epi64(xq8, _mm256_set_epi64x(6, 4, 2, 0));
+        xq8 = _mm256_and_si256(xq8, _mm256_set1_epi8(0x03));
+        // strangely enough, this is much slower with 1 instead of 2
+        xq8 = _mm256_sub_epi8(xq8, _mm256_set1_epi8(2));
+
+        const __m256i yq8 = _mm256_loadu_si256((const __m256i *) (y[i].qs));

         const __m256 q = mul_sum_i8_pairs_float(xq8, yq8);

         acc = _mm256_fmadd_ps( d, q, acc );
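
In scalar terms, the broadcast/srlv/and/sub sequence above unpacks one 8-byte block into 32 signed values, one quarter of the block per 64-bit lane, which is exactly the element order of the contiguous y[i].qs load; that is why the transposed layout pays off here. A sketch of the equivalent scalar loop, assuming qk == 32 so each block is qk/4 == 8 bytes (unpack_q2_2_scalar is a hypothetical helper for illustration):

    // mirrors: set1_epi64x -> srlv by {0, 2, 4, 6} -> and 0x03 -> sub 2
    static void unpack_q2_2_scalar(const uint8_t qs[8], int8_t out[32]) {
        for (int k = 0; k < 4; ++k) {      // one 64-bit lane per shift 2*k
            for (int j = 0; j < 8; ++j) {  // 8 bytes per lane
                out[k*8 + j] = (int8_t)((qs[j] >> (2*k)) & 3) - 2;
            }
        }
    }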
@@ -4017,11 +3953,11 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r
     for (int i = 0; i < nb; i++) {
         int sumi = 0;
         for (int j = 0; j < qk / 4; j++) {
-            const int8_t * weight = (const int8_t *)(q22_grid + x[i].qs[j]);
-            sumi += (int)y[i].qs[4*j + 0] * weight[0];
-            sumi += (int)y[i].qs[4*j + 1] * weight[1];
-            sumi += (int)y[i].qs[4*j + 2] * weight[2];
-            sumi += (int)y[i].qs[4*j + 3] * weight[3];
+            const uint8_t weight = x[i].qs[j];
+            sumi += (int)y[i].qs[j + 0*qk/4] * (((weight >> 0) & 3) - 2);
+            sumi += (int)y[i].qs[j + 1*qk/4] * (((weight >> 2) & 3) - 2);
+            sumi += (int)y[i].qs[j + 2*qk/4] * (((weight >> 4) & 3) - 2);
+            sumi += (int)y[i].qs[j + 3*qk/4] * (((weight >> 6) & 3) - 2);
         }
         sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
     }
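
For reference, the block layout all three functions assume, sketched from the loads above (QK2_2 == 32 is an assumption here, matching qk in the loops; the real definition may differ):

    #define QK2_2 32            // values per block (assumed)
    typedef struct {
        uint8_t qs[QK2_2 / 4];  // 32 ternary values, 2 bits each, 4 per byte
    } block_q2_2;

Note that block_q2_2 carries no scale of its own; only the q8_0 side's d enters the dot product, which is consistent with ternary weights in {-1, 0, +1}.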