@@ -661,8 +661,6 @@ static inline int compare_fractions_desc(const void * a, const void * b) {
661
661
// exhaustive search with cumulative sums
662
662
// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
663
663
static float make_qkxs_quants (int n , int nmin , int nmax , const float * restrict x , const float * restrict weights , int8_t * restrict L , int8_t * restrict Laux , struct fraction * restrict Faux , bool signed_scale ) {
664
- const int orig_nmin = nmin ;
665
- const int orig_nmax = nmax ;
666
664
float max = x [0 ];
667
665
float min = x [0 ];
668
666
float w_amax = weights [0 ] * fabsf (x [0 ]);
@@ -2143,6 +2141,8 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
2143
2141
2144
2142
float weight [QK4_0 ];
2145
2143
int8_t L [QK4_0 ];
2144
+ int8_t Laux [QK4_0 ];
2145
+ struct fraction Faux [8 * QK4_0 ];
2146
2146
2147
2147
float sum_x2 = 0 ;
2148
2148
for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
@@ -2153,7 +2153,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
2153
2153
const float * xb = x + QK4_0 * ib ;
2154
2154
const float * qw = quant_weights + QK4_0 * ib ;
2155
2155
for (int j = 0 ; j < QK4_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
2156
- float d = make_qx_quants (QK4_0 , 8 , xb , L , 1 , weight );
2156
+ float d = make_qkxs_quants (QK4_0 , - 8 , 7 , xb , weight , L , Laux , Faux , true );
2157
2157
y [ib ].d = GGML_FP32_TO_FP16 (d );
2158
2158
for (int j = 0 ; j < 16 ; ++ j ) {
2159
2159
y [ib ].qs [j ] = L [j ] | (L [j + 16 ] << 4 );
@@ -2231,6 +2231,8 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
2231
2231
2232
2232
float weight [QK5_0 ];
2233
2233
int8_t L [QK5_0 ];
2234
+ int8_t Laux [QK5_0 ];
2235
+ struct fraction Faux [16 * QK5_0 ];
2234
2236
2235
2237
float sum_x2 = 0 ;
2236
2238
for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
@@ -2241,7 +2243,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
2241
2243
const float * xb = x + QK5_0 * ib ;
2242
2244
const float * qw = quant_weights + QK5_0 * ib ;
2243
2245
for (int j = 0 ; j < QK5_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
2244
- float d = make_qx_quants (QK5_0 , 16 , xb , L , 1 , weight );
2246
+ float d = make_qkxs_quants (QK5_0 , - 16 , 15 , xb , weight , L , Laux , Faux , true );
2245
2247
y [ib ].d = GGML_FP32_TO_FP16 (d );
2246
2248
2247
2249
uint32_t qh = 0 ;
@@ -2403,6 +2405,74 @@ void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y,
2403
2405
}
2404
2406
}
2405
2407
2408
+ static void quantize_row_tq1_0_impl (const float * restrict x , block_tq1_0 * restrict y , int64_t n_per_row , const float * quant_weights ) {
2409
+ if (!quant_weights ) {
2410
+ quantize_row_tq1_0_ref (x , y , n_per_row );
2411
+ return ;
2412
+ }
2413
+
2414
+ float weight [QK_K ];
2415
+ int8_t L [QK_K ];
2416
+ int8_t Laux [QK_K ];
2417
+ struct fraction Faux [1 * QK_K ];
2418
+
2419
+ float sum_x2 = 0 ;
2420
+ for (int j = 0 ; j < n_per_row ; ++ j ) { sum_x2 += x [j ]* x [j ]; }
2421
+ float sigma2 = sum_x2 /n_per_row ;
2422
+
2423
+ const int64_t nb = n_per_row /QK_K ;
2424
+ for (int ib = 0 ; ib < nb ; ++ ib ) {
2425
+ const float * xb = x + QK_K * ib ;
2426
+ const float * qw = quant_weights + QK_K * ib ;
2427
+ const int8_t * Lptr = L ;
2428
+ for (int j = 0 ; j < QK_K ; ++ j ) { weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]); }
2429
+ float d = make_qkxs_quants (QK_K , -1 , 1 , xb , weight , L , Laux , Faux , false);
2430
+ y [ib ].d = GGML_FP32_TO_FP16 (d );
2431
+
2432
+ // 5 elements per byte, along 32 bytes
2433
+ for (size_t j = 0 ; j < sizeof (y -> qs ) - sizeof (y -> qs ) % 32 ; j += 32 ) {
2434
+ for (size_t m = 0 ; m < 32 ; ++ m ) {
2435
+ uint8_t q = 0 ;
2436
+ for (size_t n = 0 ; n < 5 ; ++ n ) {
2437
+ q *= 3 ;
2438
+ q += Lptr [m + n * 32 ];
2439
+ }
2440
+ // ceiling division (243 == pow(3, 5))
2441
+ q = ((uint16_t )q * 256 + (243 - 1 )) / 243 ;
2442
+ y [ib ].qs [j + m ] = q ;
2443
+ }
2444
+ Lptr += 5 * 32 ;
2445
+ }
2446
+ // along 16 bytes
2447
+ for (size_t j = sizeof (y -> qs ) - sizeof (y -> qs ) % 32 ; j < sizeof (y -> qs ); j += 16 ) {
2448
+ for (size_t m = 0 ; m < 16 ; ++ m ) {
2449
+ uint8_t q = 0 ;
2450
+ for (size_t n = 0 ; n < 5 ; ++ n ) {
2451
+ q *= 3 ;
2452
+ q += Lptr [m + n * 16 ];
2453
+ }
2454
+ // ceiling division (243 == pow(3, 5))
2455
+ q = ((uint16_t )q * 256 + (243 - 1 )) / 243 ;
2456
+ y [ib ].qs [j + m ] = q ;
2457
+ }
2458
+ Lptr += 5 * 16 ;
2459
+ }
2460
+ // 4 elements per byte
2461
+ for (size_t j = 0 ; j < sizeof (y -> qh ); ++ j ) {
2462
+ uint8_t q = 0 ;
2463
+ for (size_t m = 0 ; m < 4 ; ++ m ) {
2464
+ q *= 3 ;
2465
+ q += Lptr [j + m * sizeof (y -> qh )];
2466
+ }
2467
+ // shift the first value to the most significant trit
2468
+ q *= 3 ;
2469
+ // ceiling division (243 == pow(3, 5))
2470
+ q = ((uint16_t )q * 256 + (243 - 1 )) / 243 ;
2471
+ y [ib ].qh [j ] = q ;
2472
+ }
2473
+ }
2474
+ }
2475
+
2406
2476
void quantize_row_tq2_0_ref (const float * restrict x , block_tq2_0 * restrict y , int64_t k ) {
2407
2477
assert (k % QK_K == 0 );
2408
2478
const int64_t nb = k / QK_K ;
@@ -2435,17 +2505,69 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y,
2435
2505
}
2436
2506
}
2437
2507
2508
+
2509
+ static void quantize_row_tq2_0_impl (const float * restrict x , block_tq2_0 * restrict y , int64_t n_per_row , const float * quant_weights ) {
2510
+ if (!quant_weights ) {
2511
+ quantize_row_tq2_0_ref (x , y , n_per_row );
2512
+ return ;
2513
+ }
2514
+
2515
+ float weight [QK_K ];
2516
+ int8_t L [QK_K ];
2517
+ int8_t Laux [QK_K ];
2518
+ struct fraction Faux [2 * QK_K ];
2519
+
2520
+ float sum_x2 = 0 ;
2521
+ for (int j = 0 ; j < n_per_row ; ++ j ) { sum_x2 += x [j ]* x [j ]; }
2522
+ float sigma2 = sum_x2 /n_per_row ;
2523
+
2524
+ const int64_t nb = n_per_row /QK_K ;
2525
+ for (int ib = 0 ; ib < nb ; ++ ib ) {
2526
+ const float * xb = x + QK_K * ib ;
2527
+ const float * qw = quant_weights + QK_K * ib ;
2528
+ for (int j = 0 ; j < QK_K ; ++ j ) { weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]); }
2529
+ float d = make_qkxs_quants (QK_K , -1 , 2 , xb , weight , L , Laux , Faux , true);
2530
+ y [ib ].d = GGML_FP32_TO_FP16 (d );
2531
+
2532
+ for (size_t j = 0 ; j < sizeof (y -> qs ); j += 32 ) {
2533
+ for (size_t m = 0 ; m < 32 ; ++ m ) {
2534
+ uint8_t q = 0 ;
2535
+ for (size_t n = 0 ; n < 4 ; ++ n ) {
2536
+ q += (L [4 * j + m + n * 32 ] & 3 ) << (2 * n );
2537
+ }
2538
+ y [ib ].qs [j + m ] = q ;
2539
+ }
2540
+ }
2541
+ }
2542
+ }
2543
+
2438
2544
size_t quantize_tq1_0 (const float * restrict src , void * restrict dst , int64_t nrow , int64_t n_per_row , const float * quant_weights ) {
2439
- (void )quant_weights ; // not used
2440
- const size_t row_size = ggml_row_size (GGML_TYPE_TQ1_0 , n_per_row );
2441
- quantize_row_tq1_0_ref (src , dst , (int64_t )nrow * n_per_row );
2545
+ if (!quant_weights ) {
2546
+ quantize_row_tq1_0_ref (src , dst , (int64_t )nrow * n_per_row );
2547
+ return nrow * ggml_row_size (GGML_TYPE_TQ1_0 , n_per_row );
2548
+ }
2549
+ size_t row_size = ggml_row_size (GGML_TYPE_TQ1_0 , n_per_row );
2550
+ char * qrow = (char * )dst ;
2551
+ for (int64_t row = 0 ; row < nrow ; ++ row ) {
2552
+ quantize_row_tq1_0_impl (src , (block_tq1_0 * )qrow , n_per_row , quant_weights );
2553
+ src += n_per_row ;
2554
+ qrow += row_size ;
2555
+ }
2442
2556
return nrow * row_size ;
2443
2557
}
2444
2558
2445
2559
size_t quantize_tq2_0 (const float * restrict src , void * restrict dst , int64_t nrow , int64_t n_per_row , const float * quant_weights ) {
2446
- (void )quant_weights ; // not used
2447
- const size_t row_size = ggml_row_size (GGML_TYPE_TQ2_0 , n_per_row );
2448
- quantize_row_tq2_0_ref (src , dst , (int64_t )nrow * n_per_row );
2560
+ if (!quant_weights ) {
2561
+ quantize_row_tq2_0_ref (src , dst , (int64_t )nrow * n_per_row );
2562
+ return nrow * ggml_row_size (GGML_TYPE_TQ2_0 , n_per_row );
2563
+ }
2564
+ size_t row_size = ggml_row_size (GGML_TYPE_TQ2_0 , n_per_row );
2565
+ char * qrow = (char * )dst ;
2566
+ for (int64_t row = 0 ; row < nrow ; ++ row ) {
2567
+ quantize_row_tq2_0_impl (src , (block_tq2_0 * )qrow , n_per_row , quant_weights );
2568
+ src += n_per_row ;
2569
+ qrow += row_size ;
2570
+ }
2449
2571
return nrow * row_size ;
2450
2572
}
2451
2573
0 commit comments