@@ -592,7 +592,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc
 
 typedef struct {
     float   d;          // delta
-    uint8_t qs[QK];     // nibbles / quants
+    int8_t  qs[QK];     // quants
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
 
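
With this change a Q8_0 block is a float delta followed by QK signed byte quants: for QK = 32 that is 4 + 32 = 36 bytes per 32 weights, i.e. 1.125 bytes per weight. A minimal sketch (the helper name is ours, not from the patch) of how such a block decodes back to floats:

    /* Decoding a Q8_0 block; assumes QK == 32 as in ggml.c. */
    #include <stdint.h>

    #define QK 32

    typedef struct {
        float  d;      // delta (scale)
        int8_t qs[QK]; // signed quants
    } block_q8_0;

    static void dequantize_block_q8_0(const block_q8_0 *b, float *out) {
        for (int l = 0; l < QK; ++l) {
            out[l] = b->qs[l] * b->d;   // value = quant * delta
        }
    }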
@@ -1069,9 +1069,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         for (int l = 0; l < QK; ++l) {
             const float v = x[i*QK + l]*id;
-            const uint8_t vi = (int8_t)roundf(v) + 128;
-
-            y[i].qs[l] = vi;
+            y[i].qs[l] = roundf(v);
         }
     }
 }
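
The reference path now writes roundf(x*id) straight into a signed byte instead of adding a 128 offset. A sketch of the whole per-block routine under that scheme, reusing the block_q8_0 layout above (the amax/d/id computation follows the surrounding ggml.c code, but the function itself is ours):

    #include <math.h>

    static void quantize_block_q8_0(const float *x, block_q8_0 *y) {
        float amax = 0.0f;                        // absolute max of the block
        for (int l = 0; l < QK; ++l) {
            const float v = fabsf(x[l]);
            if (v > amax) amax = v;
        }

        const float d  = amax / 127.0f;           // delta maps [-amax, amax] onto [-127, 127]
        const float id = d ? 1.0f / d : 0.0f;     // inverse delta, guarded against an all-zero block

        y->d = d;
        for (int l = 0; l < QK; ++l) {
            y->qs[l] = (int8_t)roundf(x[l] * id); // signed quant, no +128 offset anymore
        }
    }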
@@ -1104,15 +1102,99 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         for (int l = 0; l < 8; l++) {
            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
-           const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
-           const int32x4_t   vi = vcvtq_s32_f32(vf);
+           //TODO: rounding
+           const int32x4_t   vi = vcvtq_s32_f32(v);
 
            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
         }
     }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1 );
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1 );
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1 );
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1 );
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
 #else
     // scalar
     quantize_row_q8_0_reference(x, y, k);
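
Two notes on the quantization paths above. The NEON branch still truncates (vcvtq_s32_f32 rounds toward zero, hence the //TODO: rounding), while the AVX branch rounds to nearest before converting. And the AVX2 pack sequence needs the final permute because _mm256_packs_epi32/_mm256_packs_epi16 operate on the two 128-bit lanes independently, so the 4-byte groups come out lane-interleaved. A hypothetical scalar model (not part of the patch) of that ordering and of the (0,4,1,5,2,6,3,7) fix:

    #include <stdio.h>

    int main(void) {
        /* After the two per-lane packs, each 32-bit dword of the result holds 4
         * consecutive source bytes, but the dwords appear in this order
         * (listed by the index of the first element each one holds): */
        const int dword_start[8] = { 0, 8, 16, 24, 4, 12, 20, 28 };

        /* _mm256_permutevar8x32_epi32 with indices (0,4,1,5,2,6,3,7) selects the
         * dwords back into ascending order: */
        const int perm[8] = { 0, 4, 1, 5, 2, 6, 3, 7 };

        for (int j = 0; j < 8; ++j) {
            printf("output dword %d starts at element %2d\n", j, dword_start[perm[j]]);
        }
        /* prints 0, 4, 8, ..., 28 -> the 32 bytes are back in source order */
        return 0;
    }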
@@ -2517,7 +2599,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
         const int8x16_t  s8b   = vdupq_n_s8(0x8);
-        const uint8x16_t u128b = vdupq_n_u8(128);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2535,21 +2616,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
         // load y
-        const uint8x16_t v1_0l = vld1q_u8(y0->qs);
-        const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
-        const uint8x16_t v1_1l = vld1q_u8(y1->qs);
-        const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // interleave
-        const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
-        const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
-
-        const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
-        const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
-        const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
-        const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
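
The unzip is needed because each byte of x->qs packs two 4-bit quants: element 2j in the low nibble and element 2j+1 in the high nibble, so the 32 y quants must be split into even- and odd-indexed halves before multiplying. Now that y is stored signed, vuzp1q_s8/vuzp2q_s8 do that directly and the 128-offset subtraction disappears. A scalar sketch of the pairing (our own helper, equivalent to the scalar path further down):

    #include <stdint.h>

    static int dot_q4_q8_block(const uint8_t *xqs, const int8_t *yqs, int qk) {
        int sumi = 0;
        for (int j = 0; j < qk/2; ++j) {
            const int x0 = (int8_t)(xqs[j] & 0x0f) - 8; // low  nibble -> element 2*j
            const int x1 = (int8_t)(xqs[j] >>   4) - 8; // high nibble -> element 2*j + 1
            sumi += x0 * yqs[2*j + 0] + x1 * yqs[2*j + 1];
        }
        return sumi;
    }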
@@ -2587,14 +2663,102 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     sumf = sum0 + sum1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytesFromNibbles(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert the vector of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float d1 = y[i].d;
 
         const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
 
         int sumi = 0;
         for (int j = 0; j < QK/2; j++) {
@@ -2603,10 +2767,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
             const int i0 = (int8_t) (v0 & 0xf) - 8;
             const int i1 = (int8_t) (v0 >> 4) - 8;
 
-            const int i2 = (int) p1[2*j + 0] - 128;
-            const int i3 = (int) p1[2*j + 1] - 128;
-
-            /*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
 
             sumi += i0*i2 + i1*i3;
         }
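
A note on the AVX2/AVX dot product above: _mm256_maddubs_epi16 multiplies an unsigned byte vector by a signed byte vector, so two signed inputs cannot be fed to it directly. The patch rewrites x*y as |x| * sign(x)*y, using _mm256_sign_epi8(bx, bx) for |x| and _mm256_sign_epi8(by, bx) to move x's sign onto y; maddubs then sums adjacent byte products into int16 lanes, and _mm256_madd_epi16 with a vector of ones widens and sums pairs again into int32. A hypothetical one-lane scalar model (names ours, not from the patch):

    #include <stdint.h>
    #include <stdio.h>

    static int16_t mul_like_maddubs(int8_t x, int8_t y) {
        const uint8_t ax = (uint8_t)(x < 0 ? -x : x);               /* _mm256_sign_epi8(bx, bx) */
        const int8_t  sy = (int8_t)(x < 0 ? -y : (x == 0 ? 0 : y)); /* _mm256_sign_epi8(by, bx) */
        return (int16_t)ax * (int16_t)sy;                           /* one product lane of maddubs */
    }

    int main(void) {
        printf("%d %d\n", mul_like_maddubs(-5,  7), -5 *  7); /* both print -35 */
        printf("%d %d\n", mul_like_maddubs( 6, -3),  6 * -3); /* both print -18 */
        return 0;
    }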
@@ -10171,7 +10333,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                     } else
 #endif
-                    cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                    {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                    }
                 } else {
                     GGML_ASSERT(false);
                 }
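
The braces only scope the non-BLAS fallback: the work buffer for this mul_mat node has to hold src1 requantized to Q8_0. With GGML_TYPE_SIZE[GGML_TYPE_Q8_0] = 36 (4-byte delta plus 32 quants, per the static_assert above) and GGML_BLCK_SIZE = 32, that comes to 1.125 bytes per src1 element. A rough sketch of the arithmetic (sizes assumed, src1 shape hypothetical):

    #include <stddef.h>
    #include <stdio.h>

    int main(void) {
        const size_t type_size = 36;          /* sizeof(block_q8_0) */
        const size_t blck_size = 32;          /* QK */
        const size_t nelements = 4096 * 4096; /* hypothetical src1 element count */

        printf("work buffer: %zu bytes\n", type_size * nelements / blck_size); /* 18874368 */
        return 0;
    }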