Skip to content

Commit 5d447e9

Browse files
committed
Update quantize_row_q4_0 for AVX/AVX2
1 parent da1bcb5 commit 5d447e9

File tree

1 file changed: 45 additions and 22 deletions

1 file changed: 45 additions and 22 deletions

ggml.c

Lines changed: 45 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -642,22 +642,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
642642
__m256 v3 = _mm256_loadu_ps( x + 24 );
643643
x += 32;
644644

645-
// Compute max(abs(e)) for the block
646-
const __m256 signBit = _mm256_set1_ps( -0.0f );
647-
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
648-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
649-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
650-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
651-
652-
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
645+
// Compute max for the block
646+
__m256 max = _mm256_max_ps( v0, v1 );
647+
__m256 maxTmp = _mm256_max_ps( v2, v3 );
648+
max = _mm256_max_ps( max, maxTmp );
649+
650+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
653651
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
654652
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
655653
const float maxScalar = _mm_cvtss_f32( max4 );
656654

655+
// Compute min for the block
656+
__m256 min = _mm256_min_ps( v0, v1 );
657+
__m256 minTmp = _mm256_min_ps( v2, v3 );
658+
min = _mm256_min_ps( min, minTmp );
659+
660+
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
661+
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
662+
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
663+
const float minScalar = _mm_cvtss_f32( min4 );
664+
657665
// Quantize these floats
658-
const float d = maxScalar / 7.0f;
666+
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
667+
const float d = magnitude / -8.0f;
659668
y[i].d = d;
660-
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
669+
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
661670
const __m256 mul = _mm256_set1_ps( id );
662671

663672
// Apply the multiplier
@@ -690,9 +699,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
690699
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
691700
i0 = _mm256_permutevar8x32_epi32( i0, perm );
692701

693-
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
702+
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
694703
const __m256i off = _mm256_set1_epi8( 8 );
695704
i0 = _mm256_add_epi8( i0, off );
705+
const __m256i maxNibble = _mm256_set1_epi8( 15 );
706+
i0 = _mm256_min_epi8( i0, maxNibble );
696707

697708
// Compress the vector into 4 bit/value, and store
698709
__m128i res = packNibbles( i0 );
@@ -707,22 +718,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
707718
__m256 v3 = _mm256_loadu_ps( x + 24 );
708719
x += 32;
709720

710-
// Compute max(abs(e)) for the block
711-
const __m256 signBit = _mm256_set1_ps( -0.0f );
712-
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
713-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
714-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
715-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
721+
// Compute max for the block
722+
__m256 max = _mm256_max_ps( v0, v1 );
723+
__m256 maxTmp = _mm256_max_ps( v2, v3 );
724+
max = _mm256_max_ps( max, maxTmp );
716725

717-
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
726+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
718727
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
719728
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
720729
const float maxScalar = _mm_cvtss_f32( max4 );
721730

731+
// Compute min for the block
732+
__m256 min = _mm256_min_ps( v0, v1 );
733+
__m256 minTmp = _mm256_min_ps( v2, v3 );
734+
min = _mm256_min_ps( min, minTmp );
735+
736+
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
737+
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
738+
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
739+
const float minScalar = _mm_cvtss_f32( min4 );
740+
722741
// Quantize these floats
723-
const float d = maxScalar / 7.0f;
742+
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
743+
const float d = magnitude / -8.0f;
724744
y[i].d = d;
725-
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
745+
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
726746
const __m256 mul = _mm256_set1_ps( id );
727747

728748
// Apply the multiplier
@@ -763,10 +783,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
763783
ni0 = _mm_packs_epi16( ni0, ni2 );
764784
ni4 = _mm_packs_epi16( ni4, ni6 );
765785

766-
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
767-
const __m128i off = _mm_set1_epi8( 8);
786+
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
787+
const __m128i off = _mm_set1_epi8( 8 );
768788
ni0 = _mm_add_epi8( ni0, off );
769789
ni4 = _mm_add_epi8( ni4, off );
790+
const __m128i maxNibble = _mm_set1_epi8( 15 );
791+
ni0 = _mm_min_epi8( ni0, maxNibble );
792+
ni4 = _mm_min_epi8( ni4, maxNibble );
770793

771794
// Compress the vector into 4 bit/value, and store
772795
__m128i res = packNibbles( ni0, ni4 );

0 commit comments

Comments
 (0)