@@ -642,22 +642,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
642
642
__m256 v3 = _mm256_loadu_ps ( x + 24 );
643
643
x += 32 ;
644
644
645
- // Compute max(abs(e)) for the block
646
- const __m256 signBit = _mm256_set1_ps ( -0.0f );
647
- __m256 maxAbs = _mm256_andnot_ps ( signBit , v0 );
648
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v1 ) );
649
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v2 ) );
650
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v3 ) );
651
-
652
- __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( maxAbs , 1 ), _mm256_castps256_ps128 ( maxAbs ) );
645
+ // Compute max for the block
646
+ __m256 max = _mm256_max_ps ( v0 , v1 );
647
+ __m256 maxTmp = _mm256_max_ps ( v2 , v3 );
648
+ max = _mm256_max_ps ( max , maxTmp );
649
+
650
+ __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( max , 1 ), _mm256_castps256_ps128 ( max ) );
653
651
max4 = _mm_max_ps ( max4 , _mm_movehl_ps ( max4 , max4 ) );
654
652
max4 = _mm_max_ss ( max4 , _mm_movehdup_ps ( max4 ) );
655
653
const float maxScalar = _mm_cvtss_f32 ( max4 );
656
654
655
+ // Compute min for the block
656
+ __m256 min = _mm256_min_ps ( v0 , v1 );
657
+ __m256 minTmp = _mm256_min_ps ( v2 , v3 );
658
+ min = _mm256_min_ps ( min , minTmp );
659
+
660
+ __m128 min4 = _mm_min_ps ( _mm256_extractf128_ps ( min , 1 ), _mm256_castps256_ps128 ( min ) );
661
+ min4 = _mm_min_ps ( min4 , _mm_movehl_ps ( min4 , min4 ) );
662
+ min4 = _mm_min_ss ( min4 , _mm_movehdup_ps ( min4 ) );
663
+ const float minScalar = _mm_cvtss_f32 ( min4 );
664
+
657
665
// Quantize these floats
658
- const float d = maxScalar / 7.0f ;
666
+ const float magnitude = maxScalar >= fabsf (minScalar ) ? maxScalar : minScalar ;
667
+ const float d = magnitude / -8.0f ;
659
668
y [i ].d = d ;
660
- const float id = ( maxScalar != 0.0f ) ? 7 .0f / maxScalar : 0.0f ;
669
+ const float id = ( magnitude != 0.0f ) ? -8 .0f / magnitude : 0.0f ;
661
670
const __m256 mul = _mm256_set1_ps ( id );
662
671
663
672
// Apply the multiplier
@@ -690,9 +699,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
690
699
const __m256i perm = _mm256_setr_epi32 ( 0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 );
691
700
i0 = _mm256_permutevar8x32_epi32 ( i0 , perm );
692
701
693
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
702
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
694
703
const __m256i off = _mm256_set1_epi8 ( 8 );
695
704
i0 = _mm256_add_epi8 ( i0 , off );
705
+ const __m256i maxNibble = _mm256_set1_epi8 ( 15 );
706
+ i0 = _mm256_min_epi8 ( i0 , maxNibble );
696
707
697
708
// Compress the vector into 4 bit/value, and store
698
709
__m128i res = packNibbles ( i0 );
@@ -707,22 +718,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
707
718
__m256 v3 = _mm256_loadu_ps ( x + 24 );
708
719
x += 32 ;
709
720
710
- // Compute max(abs(e)) for the block
711
- const __m256 signBit = _mm256_set1_ps ( -0.0f );
712
- __m256 maxAbs = _mm256_andnot_ps ( signBit , v0 );
713
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v1 ) );
714
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v2 ) );
715
- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v3 ) );
721
+ // Compute max for the block
722
+ __m256 max = _mm256_max_ps ( v0 , v1 );
723
+ __m256 maxTmp = _mm256_max_ps ( v2 , v3 );
724
+ max = _mm256_max_ps ( max , maxTmp );
716
725
717
- __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( maxAbs , 1 ), _mm256_castps256_ps128 ( maxAbs ) );
726
+ __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( max , 1 ), _mm256_castps256_ps128 ( max ) );
718
727
max4 = _mm_max_ps ( max4 , _mm_movehl_ps ( max4 , max4 ) );
719
728
max4 = _mm_max_ss ( max4 , _mm_movehdup_ps ( max4 ) );
720
729
const float maxScalar = _mm_cvtss_f32 ( max4 );
721
730
731
+ // Compute min for the block
732
+ __m256 min = _mm256_min_ps ( v0 , v1 );
733
+ __m256 minTmp = _mm256_min_ps ( v2 , v3 );
734
+ min = _mm256_min_ps ( min , minTmp );
735
+
736
+ __m128 min4 = _mm_min_ps ( _mm256_extractf128_ps ( min , 1 ), _mm256_castps256_ps128 ( min ) );
737
+ min4 = _mm_min_ps ( min4 , _mm_movehl_ps ( min4 , min4 ) );
738
+ min4 = _mm_min_ss ( min4 , _mm_movehdup_ps ( min4 ) );
739
+ const float minScalar = _mm_cvtss_f32 ( min4 );
740
+
722
741
// Quantize these floats
723
- const float d = maxScalar / 7.0f ;
742
+ const float magnitude = maxScalar >= fabsf (minScalar ) ? maxScalar : minScalar ;
743
+ const float d = magnitude / -8.0f ;
724
744
y [i ].d = d ;
725
- const float id = ( maxScalar != 0.0f ) ? 7 .0f / maxScalar : 0.0f ;
745
+ const float id = ( magnitude != 0.0f ) ? -8 .0f / magnitude : 0.0f ;
726
746
const __m256 mul = _mm256_set1_ps ( id );
727
747
728
748
// Apply the multiplier
@@ -763,10 +783,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
763
783
ni0 = _mm_packs_epi16 ( ni0 , ni2 );
764
784
ni4 = _mm_packs_epi16 ( ni4 , ni6 );
765
785
766
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
767
- const __m128i off = _mm_set1_epi8 ( 8 );
786
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
787
+ const __m128i off = _mm_set1_epi8 ( 8 );
768
788
ni0 = _mm_add_epi8 ( ni0 , off );
769
789
ni4 = _mm_add_epi8 ( ni4 , off );
790
+ const __m128i maxNibble = _mm_set1_epi8 ( 15 );
791
+ ni0 = _mm_min_epi8 ( ni0 , maxNibble );
792
+ ni4 = _mm_min_epi8 ( ni4 , maxNibble );
770
793
771
794
// Compress the vector into 4 bit/value, and store
772
795
__m128i res = packNibbles ( ni0 , ni4 );
0 commit comments