@@ -660,97 +660,148 @@ static inline int compare_fractions_desc(const void * a, const void * b) {
660
660
661
661
// exhaustive search with cumulative sums
662
662
// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
663
- static float make_qkxs_quants (int n , int nmin , int nmax , const float * restrict x , const float * restrict weights , int8_t * restrict L , struct fraction * restrict Faux , bool signed_scale ) {
664
- float max = 0.0f ;
665
- float amax = 0.0f ;
666
- for (int i = 0 ; i < n ; ++ i ) {
667
- float ax = fabsf (x [i ]);
668
- if (ax > amax ) {
669
- amax = ax ;
670
- max = x [i ];
663
+ static float make_qkxs_quants (int n , int nmin , int nmax , const float * restrict x , const float * restrict weights , int8_t * restrict L , int8_t * restrict Laux , struct fraction * restrict Faux , bool signed_scale ) {
664
+ const int orig_nmin = nmin ;
665
+ const int orig_nmax = nmax ;
666
+ float max = x [0 ];
667
+ float min = x [0 ];
668
+ float w_amax = weights [0 ] * fabsf (x [0 ]);
669
+ int max_i = 0 ;
670
+ int w_amax_i = 0 ;
671
+ int min_i = 0 ;
672
+ for (int i = 1 ; i < n ; ++ i ) {
673
+ if (x [i ] < min ) { min = x [i ]; min_i = i ; }
674
+ if (x [i ] > max ) { max = x [i ]; max_i = i ; }
675
+ // Find the most important value
676
+ const float w = weights [i ];
677
+ const float wax = w * fabsf (x [i ]);
678
+ if (wax > w_amax ) {
679
+ w_amax = wax ;
680
+ w_amax_i = i ;
681
+ }
682
+ }
683
+ const int amax_i = fabsf (min ) > fabsf (max ) ? min_i : max_i ;
684
+ const float amax = fabsf (x [amax_i ]);
685
+
686
+ if (amax < GROUP_MAX_EPS ) { // all zero
687
+ for (int i = 0 ; i < n ; ++ i ) {
688
+ L [i ] = 0 ;
671
689
}
690
+ return 0.0f ;
672
691
}
673
692
bool negative_scale = false;
674
693
if (signed_scale && - nmin != nmax ) {
675
694
// the max side should have the biggest range
676
- if ((max < 0.0f ) == (- nmin < nmax )) {
695
+ // FIXME: this is incorrect when the weights[.] do not sort in the same order as fabsf(x[.])
696
+ // or is it some other condition?
697
+ if ((x [amax_i ] < 0.0f ) == (- nmin < nmax )) {
677
698
// [-4, 3] ==> [-3, 4]
678
- int tmp = nmin ;
699
+ const int tmp = nmin ;
700
+ const float ftmp = min ;
679
701
nmin = - nmax ;
680
702
nmax = - tmp ;
703
+ min = - max ;
704
+ max = - ftmp ;
681
705
negative_scale = true;
682
706
}
683
707
}
684
- if (amax < GROUP_MAX_EPS ) { // all zero
708
+
709
+ // Find the max range in [0, amax_range] which doesn't result in clamping.
710
+ // This is the range from the side which would clamp first (biggest ratio of max to nmax).
711
+ int amax_range ;
712
+ float range_max ;
713
+ if (fabsf (- max * nmin ) < fabsf (- min * nmax )) {
714
+ amax_range = MAX (0 , - nmin );
715
+ range_max = fabsf (min );
716
+ } else {
717
+ amax_range = MAX (0 , nmax );
718
+ range_max = fabsf (max );
719
+ }
720
+ float sumlx = 0.0f ;
721
+ float suml2 = 0.0f ;
722
+ float scale = 0.0f ;
723
+ float best = 0.0f ;
724
+ float best_denom = 1.0f ;
725
+ if (amax_range > 1 ) {
726
+ // The smallest non-redundant iscale makes the first clamped value half+1 its max integer value.
727
+ // Proof: anything smaller has a representable vector with values twice as big.
728
+ const float iscale = ((float )(amax_range / 2 + 1 ))/range_max * (negative_scale ? -1.0f : 1.0f );
729
+ for (int i = 0 ; i < n ; ++ i ) {
730
+ const float w = weights [i ];
731
+ int l = MAX (nmin , MIN (lroundf (x [i ] * iscale ), nmax ));
732
+ if (negative_scale ) { l = - l ; }
733
+ Laux [i ] = l ;
734
+ L [i ] = l ;
735
+ suml2 += w * l * l ;
736
+ sumlx += w * l * x [i ];
737
+ }
738
+ best = sumlx * sumlx ;
739
+ best_denom = suml2 ; // should never be zero
740
+ scale = sumlx / suml2 ;
741
+ } else {
685
742
for (int i = 0 ; i < n ; ++ i ) {
743
+ Laux [i ] = 0 ;
686
744
L [i ] = 0 ;
687
745
}
688
- return 0.0f ;
689
746
}
747
+
748
+ const int imax_range = MAX (0 , (x [w_amax_i ] < 0.0f ) ? - nmin : nmax );
749
+ const int max_odd = 2 * (imax_range + 1 ) + 1 ;
750
+ const float wmax = fabsf (x [w_amax_i ]);
690
751
int n_frac = 0 ;
691
752
for (int i = 0 ; i < n ; ++ i ) {
692
753
// assuming nmin <= nmax
693
- const int odd_max = MAX (0 , x [i ] < 0 ? - nmin : nmax );
694
- const int odd_min = MAX (0 , x [i ] < 0 ? - nmax : nmin );
754
+ const int odd_max = MAX (abs ( Laux [ i ]) , x [i ] < 0.0f ? - nmin : nmax );
755
+ const int odd_min = MAX (abs ( Laux [ i ]) , x [i ] < 0.0f ? - nmax : nmin );
695
756
const float v = fabsf (x [i ]);
696
- // fprintf(stderr, "%s: i=%d, odd_min=%d, odd_max=%d\n", __func__, i, odd_min, odd_max) ;
757
+ const float v_max_odd = v * max_odd ;
697
758
for (int j = odd_min ; j < odd_max ; ++ j ) {
698
759
const float odd = 2 * j + 1 ;
699
- Faux [n_frac ++ ] = (struct fraction ){
700
- .numer = v ,
701
- .denom = odd ,
702
- .i = i ,
703
- };
760
+ if (wmax * odd < v_max_odd ) {
761
+ Faux [n_frac ++ ] = (struct fraction ){
762
+ .numer = v ,
763
+ .denom = odd ,
764
+ .i = i ,
765
+ };
766
+ } else {
767
+ // stop when the inverse scale would result in clamping the max (FIXME: most important) value
768
+ break ;
769
+ }
704
770
}
705
771
}
706
772
707
773
qsort (Faux , n_frac , sizeof (struct fraction ), compare_fractions_desc );
708
774
709
- float iscale = 0.0f ;
710
- {
711
- float sumlx = 0.0f ;
712
- float suml2 = 0.0f ;
713
- float best = 0.0f ;
714
- float best_denom = 1.0f ;
715
- for (int i = 0 ; i < n_frac ; ++ i ) {
716
- // maximize the weighted cosine
717
- const int ii = Faux [i ].i ;
718
- const float w = weights ? weights [ii ] : x [ii ] * x [ii ];
719
- sumlx += w * Faux [i ].numer ;
720
- suml2 += w * Faux [i ].denom ;
721
- const float current = sumlx * sumlx ;
722
- // fprintf(stderr, "%s: Faux[%d]=(%f/%f) * %f, square(sumlx)=%f, suml2=%f, k*cos2=%f\n", __func__, i, Faux[i].numer, Faux[i].denom, Faux[i].weight, current, suml2, current / suml2);
723
- // use the last in case of equality
724
- // FIXME: > or >= ?? Why does [0, 0, 1] rounds to [0, 0, 0] with >= ?
725
- if (suml2 > 0.0f && current * best_denom > best * suml2 ) {
726
- best = current ;
727
- best_denom = suml2 ;
728
- iscale = Faux [i ].numer > 0.0f ? Faux [i ].denom / (2.0f * Faux [i ].numer ) : 0.0f ;
729
- if (!isfinite (iscale )) {
730
- fprintf (stderr , "%s: iscale is not finite, %f/(2*%f)\n" , __func__ , Faux [i ].denom , Faux [i ].numer );
775
+ int best_p_i = -1 ; // consecutive with 0..n_frac
776
+ for (int i = 0 ; i < n_frac ; ++ i ) {
777
+ // maximize the weighted cosine
778
+ const int ii = Faux [i ].i ;
779
+ const float w = weights ? weights [ii ] : x [ii ] * x [ii ];
780
+ sumlx += w * Faux [i ].numer ;
781
+ suml2 += w * Faux [i ].denom ;
782
+ const float current = sumlx * sumlx ;
783
+ Laux [ii ] += x [ii ] < 0.0f ? -1 : 1 ;
784
+ if (suml2 > 0.0f && Faux [i ].numer > 0.0f && current * best_denom > best * suml2 ) {
785
+ best = current ;
786
+ best_denom = suml2 ;
787
+ scale = sumlx / suml2 ;
788
+ if (i == best_p_i + 1 ) {
789
+ // reduce copies for consecutive bests
790
+ L [ii ] += x [ii ] < 0.0f ? -1 : 1 ;
791
+ } else {
792
+ for (int j = 0 ; j < n ; ++ j ) {
793
+ L [j ] = Laux [j ];
731
794
}
732
795
}
796
+ best_p_i = i ;
733
797
}
734
798
}
735
- // (very) small fudging necessary because floats otherwise round to nearest even
736
- iscale = iscale * ((float )((1 << 23 ) + 1 ) / (float )(1 << 23 ));
737
-
738
- float sumlx = 0.0f ;
739
- float suml2 = 0.0f ;
740
799
for (int i = 0 ; i < n ; ++ i ) {
741
- // Rounding away from zero is assumed by the search algorithm above.
742
- int l = MAX (nmin , MIN (lroundf (x [i ] * iscale ), nmax ));
743
- if (negative_scale ) {
744
- l = - l ;
745
- }
746
- L [i ] = negative_scale ? l + nmax : l - nmin ;
747
- float w = weights ? weights [i ] : x [i ] * x [i ];
748
- // weighted projection scale
749
- sumlx += w * x [i ] * l ;
750
- suml2 += w * l * l ;
800
+ L [i ] = negative_scale ? (- L [i ] + nmax ) : (L [i ] + - nmin );
801
+ GGML_ASSERT (L [i ] >= 0 && L [i ] <= nmax - nmin );
751
802
}
752
803
753
- return suml2 > 0.0f ? sumlx / suml2 : 0.0f ;
804
+ return negative_scale ? - scale : scale ;
754
805
}
755
806
756
807
// non-linear exhaustive search with cumulative sums
@@ -1234,6 +1285,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
1234
1285
const int nb = k / QK_K ;
1235
1286
1236
1287
int8_t L [QK_K ];
1288
+ int8_t Laux [16 ];
1237
1289
struct fraction Faux [16 * 4 ];
1238
1290
float scales [QK_K / 16 ];
1239
1291
float weights [16 ];
@@ -1247,7 +1299,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
1247
1299
float max_scale = 0 ;
1248
1300
float amax = 0 ;
1249
1301
for (int j = 0 ; j < QK_K /16 ; ++ j ) {
1250
- scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weights , L + 16 * j , Faux , true);
1302
+ scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weights , L + 16 * j , Laux , Faux , true);
1251
1303
// scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
1252
1304
float scale = fabsf (scales [j ]);
1253
1305
if (scale > amax ) {
@@ -1367,6 +1419,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
1367
1419
const int nb = n_per_row / QK_K ;
1368
1420
1369
1421
int8_t L [QK_K ];
1422
+ int8_t Laux [16 ];
1370
1423
float scales [QK_K / 16 ];
1371
1424
float weight [16 ];
1372
1425
float sw [QK_K / 16 ];
@@ -1391,14 +1444,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
1391
1444
sw [j ] = sumw ;
1392
1445
1393
1446
// scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
1394
- scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weight , L + 16 * j , Faux , true);
1447
+ scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weight , L + 16 * j , Laux , Faux , true);
1395
1448
1396
1449
}
1397
1450
1398
1451
memset (y [i ].scales , 0 , 12 );
1399
1452
1400
1453
// float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
1401
- float d_block = make_qkxs_quants (QK_K /16 , -32 , 31 , scales , sw , Ls , Faux , true);
1454
+ float d_block = make_qkxs_quants (QK_K /16 , -32 , 31 , scales , sw , Ls , Laux , Faux , true);
1402
1455
for (int j = 0 ; j < QK_K /16 ; ++ j ) {
1403
1456
int l = Ls [j ];
1404
1457
if (j < 8 ) {
@@ -4856,11 +4909,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
4856
4909
for (int j = 0 ; j < block_size ; ++ j ) weight [j ] = sqrtf (sigma2 + xb [j ]* xb [j ]);
4857
4910
// for (int j = 0; j < block_size; ++j) weight[j] = 1;
4858
4911
}
4859
- float amax = 0 , max = 0 ;
4912
+ float amax = 0 ;
4860
4913
for (int j = 0 ; j < block_size ; ++ j ) {
4861
4914
float ax = fabsf (xb [j ]);
4862
4915
if (ax > amax ) {
4863
- amax = ax ; max = xb [ j ];
4916
+ amax = ax ;
4864
4917
}
4865
4918
}
4866
4919
if (amax < GROUP_MAX_EPS ) {
0 commit comments