@@ -73,11 +73,15 @@ static int sched_yield(void) {
     Sleep(0);
     return 0;
 }
+
+#define __attribute__(...)
 #else
 #include <pthread.h>
 #include <stdatomic.h>
 
 typedef void* thread_ret_t;
+
+#define __declspec(...)
 #endif
 
 // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
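The two defines added above are no-op variadic macros: the Windows branch defines `__attribute__(...)` away and the pthread branch defines `__declspec(...)` away, so a single declaration can carry both the MSVC and the GCC/Clang alignment spellings (the AVX2 error loop further down relies on this). A minimal standalone sketch of the same trick, assuming nothing from ggml.c:

```c
// Standalone illustration (not part of the patch): define the "foreign"
// alignment keyword away so one declaration compiles under both toolchains.
#include <stdio.h>

#if defined(_MSC_VER)
#define __attribute__(...)   // GCC/Clang spelling becomes a no-op under MSVC
#else
#define __declspec(...)      // MSVC spelling becomes a no-op under GCC/Clang
#endif

int main(void) {
    // Both alignment spellings appear; only the native one is honored.
    __declspec(align(32)) float out[8] __attribute__((aligned(32)));
    out[0] = 1.0f;
    printf("%f\n", out[0]);
    return 0;
}
```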
@@ -517,39 +521,120 @@ typedef struct {
 static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
 
 // reference implementation for deterministic creation of model files
-static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
-
+static inline void quantize_block_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, float scale) {
     uint8_t pp[QK/2];
 
-    for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
+    float amax = 0.0f; // absolute max
+    float max = 0.0f;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
-            amax = MAX(amax, fabsf(v));
+    for (int l = 0; l < QK; l++) {
+        const float v = x[l];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            max = v;
         }
+    }
 
-        const float d = amax / ((1 << 3) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
+    const float d = max / scale;
+    const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+    y->d = d;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = x[i*QK + l + 0]*id;
-            const float v1 = x[i*QK + l + 1]*id;
+    for (int l = 0; l < QK; l += 2) {
+        const float v0 = x[l + 0]*id;
+        const float v1 = x[l + 1]*id;
 
-            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
-            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+        int8_t vs0 = roundf(v0);
+        int8_t vs1 = roundf(v1);
 
-            assert(vi0 < 16);
-            assert(vi1 < 16);
+        vs0 = MIN(MAX(0 - 8, vs0), 15 - 8);
+        vs1 = MIN(MAX(0 - 8, vs1), 15 - 8);
 
-            pp[l/2] = vi0 | (vi1 << 4);
+        const uint8_t vi0 = vs0 + 8; // guaranteed to fit into 4 bits
+        const uint8_t vi1 = vs1 + 8; // thanks to the clamping of the signed values above
+
+        pp[l/2] = vi0 | (vi1 << 4);
+    }
+
+    memcpy(y->qs, pp, sizeof(pp));
+}
+
+static void quantize_row_q4_0_rmse(const float * restrict x, block_q4_0 * restrict y, int k) {
+    // For each q4_0 block, we try the following values to scale the shared float value
+    // and pick the one with lowest RMS error. We could do a more involved search,
+    // but this is a trade-off with speed of model generation and simplicity of the code.
+    // Operating on 8 values can reasonably be loop-unrolled or vectorized, but that is not
+    // manually done here.
+    // Values hand-picked according to histogram in examples/quantize/scale.py
+    // Include the value +7 of the old method to ensure we don't regress on RMSE on any block.
+#define Q4_0_SCALE_CANDIDATE_COUNT 8
+    static const float candidates[Q4_0_SCALE_CANDIDATE_COUNT] = { -8.7f, -8.5f, -8.3f, -8.1f, -7.9f, -7.7f, -7.2f, +7.0f };
+
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max = 0.0f;
+
+        for (int l = 0; l < QK; l++) {
+            const float v = x[i*QK + l];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
         }
 
-        memcpy(y[i].qs, pp, sizeof(pp));
+        // find scale with lowest sum of squared errors, equivalent to lowest RMS error
+        float best_sqerr = +INFINITY;
+        float best_scale = NAN;
+
+        for (int si = 0; si < Q4_0_SCALE_CANDIDATE_COUNT; si++) {
+            const float scale = candidates[si];
+            const float d = max / scale;
+            const float id = d ? 1.0f / d : 0.0f;
+            float sqe_acc = 0.f;
+#ifdef __AVX2__
+            const __m256 clamp_lo = _mm256_set1_ps( 0 - 8);
+            const __m256 clamp_hi = _mm256_set1_ps(15 - 8);
+            const __m256 id256 = _mm256_set1_ps(id);
+            for (int l = 0; l < QK; l += 8) {
+                // TODO: why are the inputs not aligned to 32 bytes?
+                __m256 v = _mm256_loadu_ps(&x[i*QK + l]);
+                v = _mm256_mul_ps(v, id256);
+                __m256 vs = _mm256_round_ps(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+                vs = _mm256_min_ps(_mm256_max_ps(clamp_lo, vs), clamp_hi);
+                const __m256 err = _mm256_sub_ps(vs, v);
+                const __m256 sqe = _mm256_mul_ps(err, err);
+
+                // this is far from optimal speed-wise, but ensures identical results to scalar implementation
+                // we have to add the floats in sqe to sqe_acc separately and in the correct order
+                // 8x _mm_add_ps(,_mm_permute_ps()) would work but isn't faster than this:
+                __declspec(align(32)) float out[8] __attribute__((aligned(32)));
+                _mm256_store_ps(out, sqe);
+                for (int ei = 0; ei < 8; ei++) {
+                    sqe_acc += out[ei];
+                }
+            }
+#else
+            for (int l = 0; l < QK; l++) {
+                const float v = x[i*QK + l] * id;
+                int8_t vs = roundf(v);
+                vs = MIN(MAX(0 - 8, vs), 15 - 8);
+                sqe_acc += (vs - v) * (vs - v);
+            }
+#endif
+            // the square error sum is calculated on un-scaled q's inside the inner loop
+            sqe_acc *= d * d;
+
+            if (best_sqerr > sqe_acc) {
+                best_sqerr = sqe_acc;
+                best_scale = scale;
+            }
+        }
+        assert(isfinite(best_sqerr));
+        assert(isfinite(best_scale));
+        quantize_block_q4_0_reference(x + i*QK, y + i, best_scale);
     }
 }
 
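For readers who want to experiment with the scale search outside of ggml.c, here is a hedged, self-contained scalar sketch of the same idea (the candidate list and the 32-value block size mirror the patch; the function name `best_q4_0_scale` and the toy data are illustrative only): for one block it tries each candidate divisor, quantizes to the signed 4-bit grid [-8, 7], and keeps the divisor whose reconstruction has the smallest sum of squared errors. Unlike the patch, which accumulates the error on unscaled quantized values and multiplies by d*d once per candidate, the sketch folds the scale in directly; the two are equivalent whenever d is nonzero.

```c
// Standalone sketch (illustrative, not the patch code): pick the candidate
// divisor that minimizes the squared q4_0 reconstruction error of one block.
#include <math.h>
#include <stdio.h>

#define BLOCK 32 // mirrors QK in ggml.c

static float best_q4_0_scale(const float * x) {
    // candidate divisors, copied from the patch
    static const float candidates[8] = { -8.7f, -8.5f, -8.3f, -8.1f, -7.9f, -7.7f, -7.2f, +7.0f };

    // signed value with the largest magnitude in the block
    float amax = 0.0f;
    float max  = 0.0f;
    for (int l = 0; l < BLOCK; l++) {
        if (amax < fabsf(x[l])) { amax = fabsf(x[l]); max = x[l]; }
    }

    float best_scale = candidates[0];
    float best_sqerr = INFINITY;
    for (int si = 0; si < 8; si++) {
        const float d  = max / candidates[si];   // shared step size for this candidate
        const float id = d ? 1.0f/d : 0.0f;
        float sqerr = 0.0f;
        for (int l = 0; l < BLOCK; l++) {
            float q = roundf(x[l]*id);           // quantize ...
            if (q < -8.0f) q = -8.0f;            // ... and clamp to the signed 4-bit range
            if (q >  7.0f) q =  7.0f;
            const float err = q*d - x[l];        // reconstruction error
            sqerr += err*err;
        }
        if (sqerr < best_sqerr) { best_sqerr = sqerr; best_scale = candidates[si]; }
    }
    return best_scale;
}

int main(void) {
    float x[BLOCK];
    for (int l = 0; l < BLOCK; l++) x[l] = sinf((float)l); // toy block
    printf("chosen scale: %.1f\n", best_q4_0_scale(x));
    return 0;
}
```

Build with something like `cc -O2 sketch.c -lm` and vary the toy data to see which candidate wins for different block shapes.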
@@ -803,7 +888,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
     }
 #else
     // scalar
-    quantize_row_q4_0_reference(x, y, k);
+    for (int i = 0; i < nb; i++) {
+        quantize_block_q4_0_reference(x + i*QK, y + i, 7);
+    }
 #endif
 }
 
@@ -10604,7 +10691,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
     for (int j = 0; j < n; j += k) {
         block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
 
-        quantize_row_q4_0_reference(src + j, y, k);
+        quantize_row_q4_0_rmse(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK; l += 2) {