@@ -783,7 +783,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
783
783
784
784
// Scale and store
785
785
for (int j = 0 ; j < 4 ; j ++ ) {
786
- __m256 result = _mm256_mul_ps (vf [j ], d_v );
786
+ const __m256 result = _mm256_mul_ps (vf [j ], d_v );
787
787
_mm256_storeu_ps (y + i * QK + l + j * 8 , result );
788
788
}
789
789
}
@@ -879,6 +879,37 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
879
879
const uint8_t * restrict pm = ((const uint8_t * )x + 0 * bs + sizeof (float ));
880
880
const uint8_t * restrict pb = ((const uint8_t * )x + 0 * bs + 2 * sizeof (float ));
881
881
882
+ #if defined(__AVX2__ )
883
+ for (int i = 0 ; i < nb ; i ++ ) {
884
+ const __m256 d_v = _mm256_broadcast_ss ((const float * ) (pd + i * bs ));
885
+ const __m256 d_m = _mm256_broadcast_ss ((const float * ) (pm + i * bs ));
886
+
887
+ const uint8_t * restrict pp = pb + i * bs ;
888
+
889
+ for (int l = 0 ; l < QK ; l += 32 ) {
890
+ // Load 32x4-bit integers into 32x8-bit integers
891
+ __m256i vx8 = bytesFromNibbles (pp + l /2 );
892
+
893
+ // Convert to 16-bit int
894
+ const __m256i vx16_lo = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (vx8 , 0 ));
895
+ const __m256i vx16_hi = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (vx8 , 1 ));
896
+
897
+ // Convert to 32-bit int -> float 32
898
+ const __m256 vf [4 ] = {
899
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_lo , 0 ))),
900
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_lo , 1 ))),
901
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_hi , 0 ))),
902
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_hi , 1 )))
903
+ };
904
+
905
+ // Scale, add m and store
906
+ for (int j = 0 ; j < 4 ; j ++ ) {
907
+ const __m256 result = _mm256_add_ps (_mm256_mul_ps (vf [j ], d_v ), d_m );
908
+ _mm256_storeu_ps (y + i * QK + l + j * 8 , result );
909
+ }
910
+ }
911
+ }
912
+ #else
882
913
for (int i = 0 ; i < nb ; i ++ ) {
883
914
const float d = * (const float * ) (pd + i * bs );
884
915
const float m = * (const float * ) (pm + i * bs );
@@ -901,6 +932,7 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
901
932
assert (!isnan (y [i * QK + l + 1 ]));
902
933
}
903
934
}
935
+ #endif
904
936
}
905
937
906
938
//
0 commit comments