Skip to content

Commit 459e93c

Browse files
authored
Add AVX2 implementation of dequantize_row_q4_1 (#505)
1 parent a316a42 commit 459e93c

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

ggml.c

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
783783

784784
// Scale and store
785785
for (int j = 0; j < 4; j++) {
786-
__m256 result = _mm256_mul_ps(vf[j], d_v);
786+
const __m256 result = _mm256_mul_ps(vf[j], d_v);
787787
_mm256_storeu_ps(y + i * QK + l + j*8, result);
788788
}
789789
}
@@ -879,6 +879,37 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
879879
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
880880
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
881881

882+
#if defined(__AVX2__)
883+
for (int i = 0; i < nb; i++) {
884+
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
885+
const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs));
886+
887+
const uint8_t * restrict pp = pb + i*bs;
888+
889+
for (int l = 0; l < QK; l += 32) {
890+
// Load 32x4-bit integers into 32x8-bit integers
891+
__m256i vx8 = bytesFromNibbles(pp+l/2);
892+
893+
// Convert to 16-bit int
894+
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
895+
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
896+
897+
// Convert to 32-bit int -> float 32
898+
const __m256 vf[4] = {
899+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
900+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
901+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
902+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
903+
};
904+
905+
// Scale, add m and store
906+
for (int j = 0; j < 4; j++) {
907+
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
908+
_mm256_storeu_ps(y + i * QK + l + j*8, result);
909+
}
910+
}
911+
}
912+
#else
882913
for (int i = 0; i < nb; i++) {
883914
const float d = *(const float *) (pd + i*bs);
884915
const float m = *(const float *) (pm + i*bs);
@@ -901,6 +932,7 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
901932
assert(!isnan(y[i*QK + l + 1]));
902933
}
903934
}
935+
#endif
904936
}
905937

906938
//

0 commit comments

Comments
 (0)