@@ -657,6 +657,35 @@ static inline __m128i packNibbles( __m256i bytes ) {
}
#endif //__loongarch_asx

+ void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) {
+     static const int qk = QK2_2;
+
+     assert(k % qk == 0);
+
+     const int nb = k / qk;
+
+     for (int i = 0; i < nb; i++) {
+
+         for (int j = 0; j < qk/4; ++j) {
+             int8_t x0 = (int8_t)x[i*qk + 0 + j];
+             int8_t x1 = (int8_t)x[i*qk + 1*qk/4 + j];
+             int8_t x2 = (int8_t)x[i*qk + 2*qk/4 + j];
+             int8_t x3 = (int8_t)x[i*qk + 3*qk/4 + j];
+
+             const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3;
+             const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3;
+             const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3;
+             const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 2 : 3;
+
+             y[i].qs[j] = 0;
+             y[i].qs[j] |= (xi0 << 0);
+             y[i].qs[j] |= (xi1 << 2);
+             y[i].qs[j] |= (xi2 << 4);
+             y[i].qs[j] |= (xi3 << 6);
+         }
+     }
+ }
+
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
    static const int qk = QK4_0;
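The new reference quantizer maps each ternary weight to a 2-bit code (-1 -> 1, 0 -> 2, +1 -> 3) and packs four codes per byte. Note the stride: the four weights that share byte j of a block sit at positions j, j + qk/4, j + 2*qk/4 and j + 3*qk/4, not at four adjacent positions. A minimal standalone sketch of how one byte gets packed, assuming QK2_2 is 32 (the block size is not visible in this excerpt):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int qk = 32;                                   // assumed value of QK2_2
        float x[32] = {0};
        x[0] = -1.0f; x[8] = 0.0f; x[16] = 1.0f; x[24] = 1.0f;  // the four weights that land in byte 0

        uint8_t qs0 = 0;
        for (int q = 0; q < 4; ++q) {
            const int8_t v  = (int8_t) x[q*qk/4 + 0];        // j = 0, quarter q of the block
            const uint8_t e = v < 0 ? 1 : v == 0 ? 2 : 3;    // -1 -> 1, 0 -> 2, +1 -> 3
            qs0 |= e << (2*q);
        }
        printf("byte 0 = 0x%02X\n", qs0);                    // prints 0xF9
        return 0;
    }

The strided layout is what makes the AVX2 unpacking further down cheap: every quarter of a block can be recovered with a single shift amount.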
@@ -1512,6 +1541,26 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#endif
}

+ void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int64_t k) {
+     static const int qk = QK2_2;
+
+     assert(k % qk == 0);
+
+     const int nb = k / qk;
+
+     for (int i = 0; i < nb; i++) {
+
+         for (int j = 0; j < qk/4; ++j) {
+             const int8_t q = x[i].qs[j];
+
+             y[i*qk + j + 0     ] = (float) (((q >> 0) & 3) - 2);
+             y[i*qk + j + 1*qk/4] = (float) (((q >> 2) & 3) - 2);
+             y[i*qk + j + 2*qk/4] = (float) (((q >> 4) & 3) - 2);
+             y[i*qk + j + 3*qk/4] = (float) (((q >> 6) & 3) - 2);
+         }
+     }
+ }
+
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
    static const int qk = QK4_0;
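Dequantization reverses the mapping: each 2-bit field, minus 2, gives back -1, 0 or +1, written to the same strided positions within the block. A minimal sketch that decodes the hypothetical byte 0xF9 from the packing example above:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint8_t q = 0xF9;                   // packed byte from the example above
        for (int s = 0; s < 8; s += 2) {
            printf("%d ", ((q >> s) & 3) - 2);    // prints: -1 0 1 1
        }
        printf("\n");
        return 0;
    }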
@@ -3876,82 +3925,18 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r
#if defined(__AVX2__)
    __m256 acc = _mm256_setzero_ps();

-     int leftovers = nb % 2;
-
-     for (int i = 0; i < nb - leftovers; i += 2) {
-
-         const __m256 d0 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 0].d) );
-         const __m256 d1 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 1].d) );
-
-         // assuming two consecutive blocks are contiguous AND aligned
-         __m128i xq16b = _mm_load_si128((const __m128i *) (x[i].qs));
-         __m256i xq16 = MM256_SET_M128I(xq16b, xq16b);
-         __m256i xq8l0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
-                             4, -1, 4, -1, 4, -1, 4, -1,
-                             1, -1, 1, -1, 1, -1, 1, -1,
-                             0, -1, 0, -1, 0, -1, 0, -1));
-         __m256i xq8h0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
-                             6, -1, 6, -1, 6, -1, 6, -1,
-                             3, -1, 3, -1, 3, -1, 3, -1,
-                             2, -1, 2, -1, 2, -1, 2, -1));
-         __m256i xq8l1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(13, -1, 13, -1, 13, -1, 13, -1,
-                             12, -1, 12, -1, 12, -1, 12, -1,
-                             9, -1, 9, -1, 9, -1, 9, -1,
-                             8, -1, 8, -1, 8, -1, 8, -1));
-         __m256i xq8h1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(15, -1, 15, -1, 15, -1, 15, -1,
-                             14, -1, 14, -1, 14, -1, 14, -1,
-                             11, -1, 11, -1, 11, -1, 11, -1,
-                             10, -1, 10, -1, 10, -1, 10, -1));
-         __m256i shift = _mm256_set_epi16(64, 16, 4, 1,
-                             64, 16, 4, 1,
-                             64, 16, 4, 1,
-                             64, 16, 4, 1);
-         xq8l0 = _mm256_mullo_epi16(xq8l0, shift);
-         xq8h0 = _mm256_mullo_epi16(xq8h0, shift);
-         xq8l1 = _mm256_mullo_epi16(xq8l1, shift);
-         xq8h1 = _mm256_mullo_epi16(xq8h1, shift);
-         xq8l0 = _mm256_srai_epi16(xq8l0, 14);
-         xq8h0 = _mm256_srai_epi16(xq8h0, 14);
-         xq8l1 = _mm256_srai_epi16(xq8l1, 14);
-         xq8h1 = _mm256_srai_epi16(xq8h1, 14);
-         __m256i xq8_0 = _mm256_packs_epi16(xq8l0, xq8h0);
-         __m256i xq8_1 = _mm256_packs_epi16(xq8l1, xq8h1);
-
-         __m256i yq8_0 = _mm256_loadu_si256((const __m256i *) (y[i + 0].qs));
-         __m256i yq8_1 = _mm256_loadu_si256((const __m256i *) (y[i + 1].qs));
-
-         const __m256 q0 = mul_sum_i8_pairs_float(xq8_0, yq8_0);
-         const __m256 q1 = mul_sum_i8_pairs_float(xq8_1, yq8_1);
-
-         acc = _mm256_fmadd_ps( d0, q0, acc );
-         acc = _mm256_fmadd_ps( d1, q1, acc );
-     }
-
-     for (int i = nb - leftovers; i < nb; ++i) {
+     for (int i = 0; i < nb; ++i) {

        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) );

-         __m128i xq8b = _mm_loadu_si64(x[i].qs);
-         __m256i xq8 = MM256_SET_M128I(xq8b, xq8b);
-         __m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
-                             4, -1, 4, -1, 4, -1, 4, -1,
-                             1, -1, 1, -1, 1, -1, 1, -1,
-                             0, -1, 0, -1, 0, -1, 0, -1));
-         __m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
-                             6, -1, 6, -1, 6, -1, 6, -1,
-                             3, -1, 3, -1, 3, -1, 3, -1,
-                             2, -1, 2, -1, 2, -1, 2, -1));
-         __m256i shift = _mm256_set_epi16(64, 16, 4, 1,
-                             64, 16, 4, 1,
-                             64, 16, 4, 1,
-                             64, 16, 4, 1);
-         xq8l = _mm256_mullo_epi16(xq8l, shift);
-         xq8h = _mm256_mullo_epi16(xq8h, shift);
-         xq8l = _mm256_srai_epi16(xq8l, 14);
-         xq8h = _mm256_srai_epi16(xq8h, 14);
-         xq8 = _mm256_packs_epi16(xq8l, xq8h);
-
-         __m256i yq8 = _mm256_loadu_si256((const __m256i *) (y[i].qs));
+         // assuming this is always aligned
+         __m256i xq8 = _mm256_set1_epi64x(*(const int64_t *) x[i].qs);
+         xq8 = _mm256_srlv_epi64(xq8, _mm256_set_epi64x(6, 4, 2, 0));
+         xq8 = _mm256_and_si256(xq8, _mm256_set1_epi8(0x03));
+         // strangely enough, this is much slower with 1 instead of 2
+         xq8 = _mm256_sub_epi8(xq8, _mm256_set1_epi8(2));
+
+         const __m256i yq8 = _mm256_loadu_si256((const __m256i *) (y[i].qs));

        const __m256 q = mul_sum_i8_pairs_float(xq8, yq8);

        acc = _mm256_fmadd_ps( d, q, acc );
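The rewritten AVX2 path drops the shuffle/multiply/arithmetic-shift unpacking (and its unrolled two-blocks-per-iteration variant) in favour of a much shorter sequence: broadcast the 8 packed bytes of a block into every 64-bit lane, shift each lane right by 0, 2, 4 or 6 bits, mask every byte to its low 2 bits, and subtract 2. The mask also discards the bits the cross-byte shift drags in from the neighbouring packed byte. Thanks to the strided packing, the 32 resulting int8 lanes line up exactly with y[i].qs, so mul_sum_i8_pairs_float can be applied directly. A scalar model of what the broadcast/srlv/and/sub sequence produces, assuming QK2_2 is 32:

    #include <stdint.h>

    // out[q*8 + b] holds field q of packed byte b, recentred to -1/0/+1,
    // i.e. the weight at block position q*8 + b, matching the order of y's int8 activations
    void unpack_q2_2_block(const uint8_t qs[8], int8_t out[32]) {
        for (int q = 0; q < 4; ++q) {        // quarter of the block, shifted by 0/2/4/6 bits
            for (int b = 0; b < 8; ++b) {    // packed byte within the block
                out[q*8 + b] = (int8_t)(((qs[b] >> (2*q)) & 3) - 2);
            }
        }
    }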
@@ -3964,11 +3949,11 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r
    for (int i = 0; i < nb; i++) {
        int sumi = 0;
        for (int j = 0; j < qk / 4; j++) {
-             const int8_t* weight = (const int8_t *)(q22_grid + x[i].qs[j]) ;
-             sumi += (int)y[i].qs[4*j+0 ] * weight[0] ;
-             sumi += (int)y[i].qs[4*j+1 ] * weight[1] ;
-             sumi += (int)y[i].qs[4*j+2 ] * weight[2] ;
-             sumi += (int)y[i].qs[4*j+3 ] * weight[3] ;
+             const uint8_t weight = x[i].qs[j];
+             sumi += (int)y[i].qs[j + 0*qk/4] * (((weight >> 0) & 3) - 2);
+             sumi += (int)y[i].qs[j + 1*qk/4] * (((weight >> 2) & 3) - 2);
+             sumi += (int)y[i].qs[j + 2*qk/4] * (((weight >> 4) & 3) - 2);
+             sumi += (int)y[i].qs[j + 3*qk/4] * (((weight >> 6) & 3) - 2);
        }
        sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
    }
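The scalar fallback now decodes each packed byte inline instead of indexing the q22_grid table (which presumably held the four pre-decoded weights for every possible byte value). Each 2-bit field is recentred by 2 before it multiplies the activation, and the only scale applied is y's per-block d; the q2_2 side carries no scale of its own in this excerpt. A minimal per-block reference, assuming QK2_2 is 32, that the AVX2 and scalar paths should both agree with:

    #include <stdint.h>

    // dot product of one q2_2 block (8 packed bytes) with 32 int8 activations, scaled by d
    float q2_2_block_dot_ref(const uint8_t qs[8], const int8_t yq[32], float d) {
        int sumi = 0;
        for (int j = 0; j < 8; ++j) {
            for (int q = 0; q < 4; ++q) {
                const int w = ((qs[j] >> (2*q)) & 3) - 2;   // decode to -1/0/+1
                sumi += w * (int) yq[j + q*8];              // strided positions, as in the quantizer
            }
        }
        return d * (float) sumi;
    }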