Skip to content

Commit d0060fc

Browse files
committed
ggml-quants : better and faster make_qkxs_quants
1 parent dd6b840 commit d0060fc

File tree

1 file changed

+116
-63
lines changed

1 file changed

+116
-63
lines changed

ggml/src/ggml-quants.c

Lines changed: 116 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -660,97 +660,148 @@ static inline int compare_fractions_desc(const void * a, const void * b) {
660660

661661
// exhaustive search with cumulative sums
662662
// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
663-
static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, struct fraction * restrict Faux, bool signed_scale) {
664-
float max = 0.0f;
665-
float amax = 0.0f;
666-
for (int i = 0; i < n; ++i) {
667-
float ax = fabsf(x[i]);
668-
if (ax > amax) {
669-
amax = ax;
670-
max = x[i];
663+
static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux, bool signed_scale) {
664+
const int orig_nmin = nmin;
665+
const int orig_nmax = nmax;
666+
float max = x[0];
667+
float min = x[0];
668+
float w_amax = weights[0] * fabsf(x[0]);
669+
int max_i = 0;
670+
int w_amax_i = 0;
671+
int min_i = 0;
672+
for (int i = 1; i < n; ++i) {
673+
if (x[i] < min) { min = x[i]; min_i = i; }
674+
if (x[i] > max) { max = x[i]; max_i = i; }
675+
// Find the most important value
676+
const float w = weights[i];
677+
const float wax = w * fabsf(x[i]);
678+
if (wax > w_amax) {
679+
w_amax = wax;
680+
w_amax_i = i;
681+
}
682+
}
683+
const int amax_i = fabsf(min) > fabsf(max) ? min_i : max_i;
684+
const float amax = fabsf(x[amax_i]);
685+
686+
if (amax < GROUP_MAX_EPS) { // all zero
687+
for (int i = 0; i < n; ++i) {
688+
L[i] = 0;
671689
}
690+
return 0.0f;
672691
}
673692
bool negative_scale = false;
674693
if (signed_scale && -nmin != nmax) {
675694
// the max side should have the biggest range
676-
if ((max < 0.0f) == (-nmin < nmax)) {
695+
// FIXME: this is incorrect when the weights[.] do not sort in the same order as fabsf(x[.])
696+
// or is it some other condition?
697+
if ((x[amax_i] < 0.0f) == (-nmin < nmax)) {
677698
// [-4, 3] ==> [-3, 4]
678-
int tmp = nmin;
699+
const int tmp = nmin;
700+
const float ftmp = min;
679701
nmin = -nmax;
680702
nmax = -tmp;
703+
min = -max;
704+
max = -ftmp;
681705
negative_scale = true;
682706
}
683707
}
684-
if (amax < GROUP_MAX_EPS) { // all zero
708+
709+
// Find the max range in [0, amax_range] which doesn't result in clamping.
710+
// This is the range from the side which would clamp first (biggest ratio of max to nmax).
711+
int amax_range;
712+
float range_max;
713+
if (fabsf(-max * nmin) < fabsf(-min * nmax)) {
714+
amax_range = MAX(0, -nmin);
715+
range_max = fabsf(min);
716+
} else {
717+
amax_range = MAX(0, nmax);
718+
range_max = fabsf(max);
719+
}
720+
float sumlx = 0.0f;
721+
float suml2 = 0.0f;
722+
float scale = 0.0f;
723+
float best = 0.0f;
724+
float best_denom = 1.0f;
725+
if (amax_range > 1) {
726+
// The smallest non-redundant iscale maps the first value that would clamp to half of its max integer value, plus one.
727+
// Proof: anything smaller has a representable vector with values twice as big.
728+
const float iscale = ((float)(amax_range / 2 + 1))/range_max * (negative_scale ? -1.0f : 1.0f);
729+
for (int i = 0; i < n; ++i) {
730+
const float w = weights[i];
731+
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
732+
if (negative_scale) { l = -l; }
733+
Laux[i] = l;
734+
L[i] = l;
735+
suml2 += w * l * l;
736+
sumlx += w * l * x[i];
737+
}
738+
best = sumlx * sumlx;
739+
best_denom = suml2; // should never be zero
740+
scale = sumlx / suml2;
741+
} else {
685742
for (int i = 0; i < n; ++i) {
743+
Laux[i] = 0;
686744
L[i] = 0;
687745
}
688-
return 0.0f;
689746
}
747+
748+
const int imax_range = MAX(0, (x[w_amax_i] < 0.0f) ? -nmin : nmax);
749+
const int max_odd = 2*(imax_range + 1) + 1;
750+
const float wmax = fabsf(x[w_amax_i]);
690751
int n_frac = 0;
691752
for (int i = 0; i < n; ++i) {
692753
// assuming nmin <= nmax
693-
const int odd_max = MAX(0, x[i] < 0 ? -nmin : nmax);
694-
const int odd_min = MAX(0, x[i] < 0 ? -nmax : nmin);
754+
const int odd_max = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmin : nmax);
755+
const int odd_min = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmax : nmin);
695756
const float v = fabsf(x[i]);
696-
// fprintf(stderr, "%s: i=%d, odd_min=%d, odd_max=%d\n", __func__, i, odd_min, odd_max);
757+
const float v_max_odd = v * max_odd;
697758
for (int j = odd_min; j < odd_max; ++j) {
698759
const float odd = 2*j + 1;
699-
Faux[n_frac++] = (struct fraction){
700-
.numer=v,
701-
.denom=odd,
702-
.i=i,
703-
};
760+
if (wmax * odd < v_max_odd) {
761+
Faux[n_frac++] = (struct fraction){
762+
.numer=v,
763+
.denom=odd,
764+
.i=i,
765+
};
766+
} else {
767+
// stop when the inverse scale would result in clamping the max (FIXME: most important) value
768+
break;
769+
}
704770
}
705771
}
706772

707773
qsort(Faux, n_frac, sizeof(struct fraction), compare_fractions_desc);
708774

709-
float iscale = 0.0f;
710-
{
711-
float sumlx = 0.0f;
712-
float suml2 = 0.0f;
713-
float best = 0.0f;
714-
float best_denom = 1.0f;
715-
for (int i = 0; i < n_frac; ++i) {
716-
// maximize the weighted cosine
717-
const int ii = Faux[i].i;
718-
const float w = weights ? weights[ii] : x[ii] * x[ii];
719-
sumlx += w * Faux[i].numer;
720-
suml2 += w * Faux[i].denom;
721-
const float current = sumlx * sumlx;
722-
// fprintf(stderr, "%s: Faux[%d]=(%f/%f) * %f, square(sumlx)=%f, suml2=%f, k*cos2=%f\n", __func__, i, Faux[i].numer, Faux[i].denom, Faux[i].weight, current, suml2, current / suml2);
723-
// use the last in case of equality
724-
// FIXME: > or >= ?? Why does [0, 0, 1] round to [0, 0, 0] with >= ?
725-
if (suml2 > 0.0f && current * best_denom > best * suml2) {
726-
best = current;
727-
best_denom = suml2;
728-
iscale = Faux[i].numer > 0.0f ? Faux[i].denom / (2.0f * Faux[i].numer) : 0.0f;
729-
if (!isfinite(iscale)) {
730-
fprintf(stderr, "%s: iscale is not finite, %f/(2*%f)\n", __func__, Faux[i].denom, Faux[i].numer);
775+
int best_p_i = -1; // consecutive with 0..n_frac
776+
for (int i = 0; i < n_frac; ++i) {
777+
// maximize the weighted cosine
778+
const int ii = Faux[i].i;
779+
const float w = weights ? weights[ii] : x[ii] * x[ii];
780+
sumlx += w * Faux[i].numer;
781+
suml2 += w * Faux[i].denom;
782+
const float current = sumlx * sumlx;
783+
Laux[ii] += x[ii] < 0.0f ? -1 : 1;
784+
if (suml2 > 0.0f && Faux[i].numer > 0.0f && current * best_denom > best * suml2) {
785+
best = current;
786+
best_denom = suml2;
787+
scale = sumlx / suml2;
788+
if (i == best_p_i + 1) {
789+
// reduce copies for consecutive bests
790+
L[ii] += x[ii] < 0.0f ? -1 : 1;
791+
} else {
792+
for (int j = 0; j < n; ++j) {
793+
L[j] = Laux[j];
731794
}
732795
}
796+
best_p_i = i;
733797
}
734798
}
735-
// (very) small fudging necessary because floats otherwise round to nearest even
736-
iscale = iscale * ((float)((1 << 23) + 1) / (float)(1 << 23));
737-
738-
float sumlx = 0.0f;
739-
float suml2 = 0.0f;
740799
for (int i = 0; i < n; ++i) {
741-
// Rounding away from zero is assumed by the search algorithm above.
742-
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
743-
if (negative_scale) {
744-
l = -l;
745-
}
746-
L[i] = negative_scale ? l + nmax : l - nmin;
747-
float w = weights ? weights[i] : x[i] * x[i];
748-
// weighted projection scale
749-
sumlx += w * x[i] * l;
750-
suml2 += w * l * l;
800+
L[i] = negative_scale ? (-L[i] + nmax) : (L[i] + -nmin);
801+
GGML_ASSERT(L[i] >= 0 && L[i] <= nmax - nmin);
751802
}
752803

753-
return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
804+
return negative_scale ? -scale : scale;
754805
}
755806

756807
// non-linear exhaustive search with cumulative sums
@@ -1234,6 +1285,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
12341285
const int nb = k / QK_K;
12351286

12361287
int8_t L[QK_K];
1288+
int8_t Laux[16];
12371289
struct fraction Faux[16 * 4];
12381290
float scales[QK_K / 16];
12391291
float weights[16];
@@ -1247,7 +1299,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
12471299
float max_scale = 0;
12481300
float amax = 0;
12491301
for (int j = 0; j < QK_K/16; ++j) {
1250-
scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Faux, true);
1302+
scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Laux, Faux, true);
12511303
// scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
12521304
float scale = fabsf(scales[j]);
12531305
if (scale > amax) {
@@ -1367,6 +1419,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
13671419
const int nb = n_per_row / QK_K;
13681420

13691421
int8_t L[QK_K];
1422+
int8_t Laux[16];
13701423
float scales[QK_K / 16];
13711424
float weight[16];
13721425
float sw[QK_K / 16];
@@ -1391,14 +1444,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
13911444
sw[j] = sumw;
13921445

13931446
// scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
1394-
scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Faux, true);
1447+
scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Laux, Faux, true);
13951448

13961449
}
13971450

13981451
memset(y[i].scales, 0, 12);
13991452

14001453
// float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
1401-
float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Faux, true);
1454+
float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Laux, Faux, true);
14021455
for (int j = 0; j < QK_K/16; ++j) {
14031456
int l = Ls[j];
14041457
if (j < 8) {
@@ -4856,11 +4909,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48564909
for (int j = 0; j < block_size; ++j) weight[j] = sqrtf(sigma2 + xb[j]*xb[j]);
48574910
// for (int j = 0; j < block_size; ++j) weight[j] = 1;
48584911
}
4859-
float amax = 0, max = 0;
4912+
float amax = 0;
48604913
for (int j = 0; j < block_size; ++j) {
48614914
float ax = fabsf(xb[j]);
48624915
if (ax > amax) {
4863-
amax = ax; max = xb[j];
4916+
amax = ax;
48644917
}
48654918
}
48664919
if (amax < GROUP_MAX_EPS) {

0 commit comments

Comments
 (0)