@@ -687,11 +687,11 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
         }
         return 0.0f;
     }
+
     bool negative_scale = false;
     if (signed_scale && -nmin != nmax) {
         // the max side should have the biggest range
-        // FIXME: this is incorrect when the weights[.] do not sort in the same order as fabsf(x[.])
-        // or is it some other condition?
+        // FIXME: this is not always the best sign
         if ((x[amax_i] < 0.0f) == (-nmin < nmax)) {
            // [-4, 3] ==> [-3, 4]
            const int tmp = nmin;
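(For context, a minimal sketch of what the truncated branch above is assumed to continue into, inferred from the `negative_scale` flag and the `return negative_scale ? -scale : scale;` line further down: when the sign of the absmax value does not point at the larger half of [nmin, nmax], the range is mirrored so the absmax lands on the wide side, and the scale is negated on return. Not part of the patch.)

    // sketch only -- assumed continuation of the branch above
    if ((x[amax_i] < 0.0f) == (-nmin < nmax)) {
        // [-4, 3] ==> [-3, 4]
        const int tmp = nmin;
        nmin = -nmax;
        nmax = -tmp;
        negative_scale = true;
    }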
@@ -762,7 +762,7 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
                     .i = i,
                 };
             } else {
-                // stop when the inverse scale would result in clamping the max (FIXME: most important) value
+                // stop when the inverse scale would result in clamping the most important value
                 break;
             }
         }
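The fractions collected in this loop are the rounding thresholds of the search: for a value x[i], lroundf(x[i] * iscale) ticks from j to j + 1 exactly when the inverse scale crosses (2*j + 1) / (2*|x[i]|), so only those points can change the candidate quantization. A small standalone illustration of that (illustrative only, not part of the patch; plain C, no ggml dependencies):

#include <math.h>
#include <stdio.h>

int main(void) {
    const float x = 0.8f; // any nonzero value works
    for (int j = 0; j < 4; ++j) {
        // tipping point where round(x * iscale) goes from j to j + 1
        const float threshold = (2*j + 1) / (2.0f * x);
        printf("j=%d threshold=%.4f round(just below)=%ld round(just above)=%ld\n",
               j, threshold, lroundf(x * (threshold - 1e-3f)), lroundf(x * (threshold + 1e-3f)));
    }
    return 0;
}

Sorting the fractions in descending order then walks the inverse scale upward through these thresholds one quant increment at a time, which is what allows sumlx/suml2 to be updated incrementally instead of being recomputed for every candidate.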
@@ -802,6 +802,182 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
     return negative_scale ? -scale : scale;
 }
 
+// Very similar to make_qkxs_quants, but the sign of the scale is not assumed to be the sign of the absmax value.
+static float make_qkxss_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux) {
+    // start at zero
+    nmin = MIN(0, nmin);
+    nmax = MAX(0, nmax);
+    float amax = 0.0f;
+    float min = 0.0f;
+    float max = 0.0f;
+    float w_amax = 0.0f;
+    int amax_i = -1;
+    int w_amax_i = -1;
+    for (int i = 0; i < n; ++i) {
+        const float w = weights ? weights[i] : x[i] * x[i];
+        const float ax = fabsf(x[i]);
+        const float wax = w * ax;
+        if (ax > amax) { amax = ax; amax_i = i; }
+        if (x[i] > max) { max = x[i]; }
+        if (x[i] < min) { min = x[i]; }
+        // Find the most important value
+        if (wax > w_amax) { w_amax = wax; w_amax_i = i; }
+    }
+
+    if (amax < GROUP_MAX_EPS || amax_i < 0 || w_amax_i < 0) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.0f;
+    }
+
+    // Use the side which will clamp first.
+    // The first clamped value is the absmax at the end of the common range.
+    // TODO: reduce the search space when one of the ranges is 0
+    const int amax_range = MIN(-nmin, nmax);
+    float sumlx_p = 0.0f;
+    float suml2_p = 0.0f;
+    float sumlx_n = 0.0f;
+    float suml2_n = 0.0f;
+    float scale = 0.0f;
+    float best = 0.0f;
+    float best_denom = 1.0f;
+    int best_i = -2; // not consecutive with 0..n_frac
+    // Pre-calculate the half-point for the common range.
+    // All smaller vectors have a representable vector with twice the values, and thus can be skipped.
+    if (amax_range > 1) {
+        const float iscale = ((float)(amax_range / 2 + 1))/amax;
+        for (int i = 0; i < n; ++i) {
+            const float w = weights ? weights[i] : x[i] * x[i];
+            int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
+            Laux[i] = l;
+            suml2_p += w * l * l;
+            sumlx_p += w * l * x[i];
+        }
+        sumlx_n = -sumlx_p;
+        suml2_n = suml2_p;
+        const float current_p = sumlx_p * sumlx_p;
+        if (suml2_p > 0.0f && current_p * best_denom > best * suml2_p) {
+            best = current_p;
+            best_denom = suml2_p;
+            scale = sumlx_p / suml2_p;
+            for (int i = 0; i < n; ++i) {
+                L[i] = Laux[i];
+            }
+            best_i = -1; // right before 0 of the loop after sorting
+        }
+    } else {
+        for (int i = 0; i < n; ++i) {
+            Laux[i] = 0;
+        }
+    }
+
+    const int imax_range = MAX(nmax, -nmin);
+    const int max_odd = 2*(imax_range + 1) + 1;
+    const float wmax = fabsf(x[w_amax_i]);
+    int n_frac = 0;
+    for (int i = 0; i < n; ++i) {
+        // assuming nmin <= nmax
+        const int odd_max = MAX(nmax, -nmin);
+        const float v = fabsf(x[i]);
+        const float v_max_odd = v * max_odd;
+        for (int j = abs(Laux[i]); j < odd_max; ++j) {
+            const float odd = 2*j + 1;
+            const float wmax_odd = wmax * odd;
+            if (wmax_odd < v_max_odd) {
+                Faux[n_frac++] = (struct fraction){
+                    .numer = v,
+                    .denom = odd,
+                    .i = i,
+                };
+            } else {
+                // stop when the inverse scale would result in clamping the most important value
+                break;
+            }
+        }
+    }
+
+    qsort(Faux, n_frac, sizeof(struct fraction), compare_fractions_desc);
+
+    const float max_common_odd = (MIN(nmax, -nmin) * 2) + 1;
+    const float max_odd_p = (nmax * 2) + 1;
+    const float max_odd_n = (-nmin * 2) + 1;
+
+    for (int i = 0; i < n_frac; ++i) {
+        // maximize the weighted cosine similarity
+        const int ii = Faux[i].i;
+        const float w = weights ? weights[ii] : x[ii] * x[ii];
+        const float lx = w * Faux[i].numer;
+        const float odd = Faux[i].denom;
+        const float l2 = w * odd;
+
+        Laux[ii] += x[ii] < 0.0f ? -1 : 1;
+
+        float sumlx = 0.0f;
+        float proj = 0.0f;
+        float norm = 0.0f;
+        if (odd < max_common_odd) {
+            sumlx_p += lx;
+            suml2_p += l2;
+            sumlx_n -= lx;
+            suml2_n += l2;
+
+            sumlx = sumlx_p;
+            proj = sumlx_p * sumlx_p;
+            norm = suml2_p;
+
+            // avoid double-copying Laux in a single iteration
+            if (suml2_p != suml2_n && suml2_p * suml2_n > 0.0f) {
+                const float proj_n = sumlx_n * sumlx_n;
+                if (proj_n * norm > proj * suml2_n) {
+                    sumlx = sumlx_n;
+                    proj = proj_n;
+                    norm = suml2_n;
+                }
+            }
+        } else if (x[ii] < 0.0f ? odd < max_odd_n : odd < max_odd_p) {
+            sumlx_p += lx;
+            suml2_p += l2;
+
+            sumlx = sumlx_p;
+            proj = sumlx_p * sumlx_p;
+            norm = suml2_p;
+        } else {
+            // outside the positive range means we're now into negatives
+            sumlx_n -= lx;
+            suml2_n += l2;
+
+            sumlx = sumlx_n;
+            proj = sumlx_n * sumlx_n;
+            norm = suml2_n;
+        }
+        if (norm > 0.0f && proj * best_denom > best * norm) {
+            best = proj;
+            best_denom = norm;
+            scale = sumlx / norm;
+            if (i == best_i + 1) {
+                // reduce copies for consecutive bests
+                L[ii] += x[ii] < 0.0f ? -1 : 1;
+            } else {
+                for (int j = 0; j < n; ++j) {
+                    L[j] = Laux[j];
+                }
+            }
+            best_i = i;
+        }
+    }
+
+    if (scale < 0.0f) {
+        for (int i = 0; i < n; ++i) {
+            L[i] = MAX(nmin, MIN(-L[i], nmax)) - nmin;
+        }
+    } else {
+        for (int i = 0; i < n; ++i) {
+            L[i] = MAX(nmin, MIN(L[i], nmax)) - nmin;
+        }
+    }
+
+    return scale;
+}
+
 // non-linear exhaustive search with cumulative sums
 // Need Faux to have room for n*k fractions
 static float make_qkxs_nl_quants(int n, int k, const float * restrict x, const float * restrict weights, const int8_t * restrict kvalues, uint8_t * restrict L, uint8_t * restrict Laux, struct fraction * restrict Faux, bool signed_scale) {
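For reference, the quantity both make_qkxs_quants and make_qkxss_quants maximize for a candidate quantization L of x with importance weights w is (sum_i w[i]*L[i]*x[i])^2 / (sum_i w[i]*L[i]^2), i.e. the squared weighted projection over the weighted norm of L (equivalently, the weighted cosine similarity between x and its reconstruction), and the matching scale is sumlx / suml2. A standalone sketch of that objective, usable as a brute-force cross-check of the incremental sums above (illustrative only, not part of the patch; the helper name is made up):

#include <stdint.h>
#include <stdio.h>

// squared-projection objective for a fixed candidate L; also returns the best scale for that L
static float candidate_objective(int n, const float * x, const float * w, const int8_t * L, float * scale) {
    float sumlx = 0.0f;
    float suml2 = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumlx += w[i] * L[i] * x[i];
        suml2 += w[i] * L[i] * L[i];
    }
    *scale = suml2 > 0.0f ? sumlx / suml2 : 0.0f;
    return suml2 > 0.0f ? sumlx * sumlx / suml2 : 0.0f;
}

int main(void) {
    const float  x[4] = { 0.8f, -0.3f, 0.1f, 1.9f };
    const float  w[4] = { 1.0f,  1.0f, 1.0f, 1.0f };
    const int8_t a[4] = { 1, 0, 0, 2 };   // two hand-picked candidates for a [-1, 2] range
    const int8_t b[4] = { 1, -1, 0, 2 };
    float sa, sb;
    printf("a: obj=%f scale=%f\n", candidate_objective(4, x, w, a, &sa), sa);
    printf("b: obj=%f scale=%f\n", candidate_objective(4, x, w, b, &sb), sb);
    return 0;
}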
@@ -874,6 +1050,7 @@ static float make_qkxs_nl_quants(int n, int k, const float * restrict x, const f
     }
 
     // Non-linear mappings are usually not symmetric, so try negating the scale
+    // This is the same as above, but keeping the old best if the new best is not better.
     if (signed_scale) {
         for (int i = 0; i < n; ++i) {
             Laux[i] = koff;
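The new-versus-old comparisons in make_qkxss_quants above avoid division by cross-multiplying the two ratios (proj/norm against best/best_denom), which is valid because both denominators are sums of non-negative weighted terms and are checked to be positive; the "keeping the old best" behaviour described in this comment presumably relies on the same test shape. A toy standalone illustration with assumed numbers (not from the patch):

#include <stdio.h>

int main(void) {
    // current best ratio 9/2 = 4.5; candidate from the negated scale 25/6 ~= 4.17
    float best = 9.0f, best_denom = 2.0f;
    const float proj = 25.0f, norm = 6.0f;
    // same test shape as `proj * best_denom > best * norm` in the search loops
    if (norm > 0.0f && proj * best_denom > best * norm) {
        best = proj;
        best_denom = norm;
        printf("candidate wins\n");
    } else {
        printf("keep the old best\n"); // taken here: 25*2 = 50 is not > 9*6 = 54
    }
    return 0;
}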
@@ -1298,7 +1475,6 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
         float amax = 0;
         for (int j = 0; j < QK_K/16; ++j) {
             scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Laux, Faux, true);
-            // scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
             float scale = fabsf(scales[j]);
             if (scale > amax) {
                 amax = scale; max_scale = scales[j];
@@ -1324,21 +1500,6 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
             y[i].d = GGML_FP32_TO_FP16(0.f);
         }
 
-        // int8_t sc;
-        // for (int j = 0; j < QK_K/16; ++j) {
-        //     sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
-        //     sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
-        //     float d = GGML_FP16_TO_FP32(y[i].d) * sc;
-        //     if (!d) {
-        //         continue;
-        //     }
-        //     for (int ii = 0; ii < 16; ++ii) {
-        //         int l = nearest_int(x[16*j + ii]/d);
-        //         l = MAX(-4, MIN(3, l));
-        //         L[16*j + ii] = l + 4;
-        //     }
-        // }
-
         memset(y[i].hmask, 0, QK_K/8);
         // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
         int m = 0;
@@ -1441,14 +1602,12 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
             for (int l = 0; l < 16; ++l) sumw += weight[l];
             sw[j] = sumw;
 
-            // scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
             scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Laux, Faux, true);
 
         }
 
         memset(y[i].scales, 0, 12);
 
-        // float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
         float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Laux, Faux, true);
         for (int j = 0; j < QK_K/16; ++j) {
             int l = Ls[j];
@@ -1462,21 +1621,6 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
         }
         y[i].d = GGML_FP32_TO_FP16(d_block);
 
-        // int8_t sc;
-        // for (int j = 0; j < QK_K/16; ++j) {
-        //     sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
-        //     sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
-        //     float d = GGML_FP16_TO_FP32(y[i].d) * sc;
-        //     if (!d) {
-        //         continue;
-        //     }
-        //     for (int ii = 0; ii < 16; ++ii) {
-        //         int l = nearest_int(x[16*j + ii]/d);
-        //         l = MAX(-4, MIN(3, l));
-        //         L[16*j + ii] = l + 4;
-        //     }
-        // }
-
         memset(y[i].hmask, 0, QK_K/8);
         // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
         int m = 0;
@@ -2526,7 +2670,7 @@ static void quantize_row_tq2_0_impl(const float * restrict x, block_tq2_0 * rest
         const float * xb = x + QK_K * ib;
         const float * qw = quant_weights + QK_K * ib;
         for (int j = 0; j < QK_K; ++j) { weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); }
-        float d = make_qkxs_quants(QK_K, -1, 2, xb, weight, L, Laux, Faux, true);
+        float d = make_qkxss_quants(QK_K, -1, 2, xb, weight, L, Laux, Faux);
         y[ib].d = GGML_FP32_TO_FP16(d);
 
         for (size_t j = 0; j < sizeof(y->qs); j += 32) {
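With nmin = -1 and nmax = 2 as used here, the L values returned by make_qkxss_quants are already clamped and offset by -nmin into [0, 3], which fits the 2 bits per value that the packing loop over y[ib].qs stores. A quick standalone check of that clamp-and-offset step at the end of make_qkxss_quants (illustrative only, not part of the patch):

#include <stdio.h>

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nmin = -1, nmax = 2;
    // raw signed quants as they may appear during the search, before the final remap
    for (int l = -2; l <= 3; ++l) {
        printf("raw %2d -> stored %d\n", l, MAX(nmin, MIN(l, nmax)) - nmin);
    }
    return 0;
}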