Commit 2c4f9b6

Q8: use int8_t, AVX/AVX2 optimizations

sw authored and ggerganov committed
1 parent 19e7a65 commit 2c4f9b6

File tree

1 file changed: +190 −26 lines changed

ggml.c

Lines changed: 190 additions & 26 deletions
@@ -592,7 +592,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc
 
 typedef struct {
     float d;        // delta
-    uint8_t qs[QK]; // nibbles / quants
+    int8_t  qs[QK]; // quants
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");

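Note: with this struct change, q8_0 stores signed quants centered on zero, so dequantization is a single multiply with no +128 bias to undo. A minimal standalone sketch of the encode/decode round trip under this layout (QK fixed at 32 as in ggml at the time; helper names like q8_0_encode are illustrative, not ggml API):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    #define QK 32

    typedef struct {
        float  d;       // delta (scale) = max|x| / 127
        int8_t qs[QK];  // signed quants in [-127, 127]
    } block_q8_0;

    // Quantize one block, mirroring the diff's scalar reference path
    static void q8_0_encode(const float *x, block_q8_0 *y) {
        float amax = 0.0f;
        for (int l = 0; l < QK; ++l) {
            const float a = fabsf(x[l]);
            if (a > amax) amax = a;
        }
        y->d = amax / 127.0f;
        const float id = amax != 0.0f ? 127.0f / amax : 0.0f;
        for (int l = 0; l < QK; ++l) {
            y->qs[l] = (int8_t) roundf(x[l] * id);
        }
    }

    // Dequantize: one multiply per element, no bias to undo
    static void q8_0_decode(const block_q8_0 *y, float *x) {
        for (int l = 0; l < QK; ++l) {
            x[l] = y->d * y->qs[l];
        }
    }

    int main(void) {
        float x[QK], r[QK];
        for (int l = 0; l < QK; ++l) x[l] = sinf((float) l);
        block_q8_0 b;
        q8_0_encode(x, &b);
        q8_0_decode(&b, r);
        printf("x[3]=%f  r[3]=%f\n", x[3], r[3]);
        return 0;
    }
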
@@ -1069,9 +1069,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         for (int l = 0; l < QK; ++l) {
             const float v = x[i*QK + l]*id;
-            const uint8_t vi = (int8_t)roundf(v) + 128;
-
-            y[i].qs[l] = vi;
+            y[i].qs[l] = roundf(v);
         }
     }
 }
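Note: the old encoding stored the same information with the zero point shifted to 128; the new code writes the signed value directly, and the implicit float-to-int8_t conversion in y[i].qs[l] = roundf(v) is exact because id scales v into [-127, 127]. A tiny self-contained comparison of the two encodings for one value (a sketch, not ggml code):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const float v = -3.4f;
        // old scheme: round, then bias into unsigned range
        const uint8_t old_enc = (uint8_t)((int8_t) roundf(v) + 128); // 125
        // new scheme: store the signed value directly
        const int8_t  new_enc = (int8_t) roundf(v);                  // -3
        printf("old=%d new=%d same_info=%d\n",
               (int) old_enc, (int) new_enc,
               ((int) old_enc - 128) == (int) new_enc);
        return 0;
    }
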
@@ -1104,15 +1102,99 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         for (int l = 0; l < 8; l++) {
             const float32x4_t v = vmulq_n_f32(srcv[l], id);
-            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
-            const int32x4_t vi = vcvtq_s32_f32(vf);
+            //TODO: rounding
+            const int32x4_t vi = vcvtq_s32_f32(v);
 
             y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
         }
     }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // AVX lacks some of the required 256-bit integer instructions,
+        // so we split the registers in half and use the SSE analogs of the AVX2 instructions
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
 #else
     // scalar
     quantize_row_q8_0_reference(x, y, k);
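Note: the subtle part of the new AVX2 quantization path is the store order. _mm256_packs_epi32 and _mm256_packs_epi16 narrow each 128-bit lane independently, so after two pack steps the 32 bytes land in the lane-interleaved order spelled out in the comments, and the dword permute (0, 4, 1, 5, 2, 6, 3, 7) restores sequential order. A scalar model of that index bookkeeping, feeding in the element indices 0..31 so the final order is visible (pack saturation is ignored since only small indices flow through; this is a sketch, not the intrinsics):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    // Model _mm256_packs_epi32: per 128-bit lane, 4 dwords of a then 4 of b
    static void packs32(const int32_t a[8], const int32_t b[8], int16_t out[16]) {
        for (int lane = 0; lane < 2; ++lane)
            for (int k = 0; k < 4; ++k) {
                out[8*lane + k]     = (int16_t) a[4*lane + k];
                out[8*lane + 4 + k] = (int16_t) b[4*lane + k];
            }
    }

    // Model _mm256_packs_epi16: per 128-bit lane, 8 words of a then 8 of b
    static void packs16(const int16_t a[16], const int16_t b[16], int8_t out[32]) {
        for (int lane = 0; lane < 2; ++lane)
            for (int k = 0; k < 8; ++k) {
                out[16*lane + k]     = (int8_t) a[8*lane + k];
                out[16*lane + 8 + k] = (int8_t) b[8*lane + k];
            }
    }

    int main(void) {
        int32_t v[4][8]; // i0..i3 hold the element indices 0..31
        for (int r = 0; r < 4; ++r)
            for (int c = 0; c < 8; ++c) v[r][c] = 8*r + c;

        int16_t p01[16], p23[16];
        packs32(v[0], v[1], p01);
        packs32(v[2], v[3], p23);

        int8_t packed[32];
        packs16(p01, p23, packed); // lane-interleaved, as in the diff comments

        // Model _mm256_permutevar8x32_epi32 with perm = {0,4,1,5,2,6,3,7}:
        // dst dword g = src dword perm[g]
        const int perm[8] = { 0, 4, 1, 5, 2, 6, 3, 7 };
        int8_t fixed[32];
        for (int g = 0; g < 8; ++g)
            memcpy(fixed + 4*g, packed + 4*perm[g], 4);

        for (int j = 0; j < 32; ++j) printf("%d ", fixed[j]);
        printf("\n");
        return 0;
    }

Running it should print 0..31 in order, confirming that the 32-bit permute undoes the lane interleaving of the two pack steps.
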
@@ -2517,7 +2599,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
         const int8x16_t  s8b = vdupq_n_s8(0x8);
-        const uint8x16_t u128b = vdupq_n_u8(128);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2535,21 +2616,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
         // load y
-        const uint8x16_t v1_0l = vld1q_u8(y0->qs);
-        const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
-        const uint8x16_t v1_1l = vld1q_u8(y1->qs);
-        const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // interleave
-        const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
-        const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
-
-        const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
-        const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
-        const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
-        const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
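Note: because the q8_0 bytes are now already signed, the loads become vld1q_s8 and the whole bias-subtraction step (u128b) disappears; the vuzp1q_s8/vuzp2q_s8 pair only has to de-interleave y so that even-indexed elements line up with the low nibbles of x and odd-indexed elements with the high nibbles, matching the pairing visible in the scalar fallback further down. A scalar model of that de-interleave (an illustrative helper, not NEON code):

    #include <stdint.h>
    #include <stdio.h>

    // vuzp1q_s8 keeps even-indexed bytes, vuzp2q_s8 keeps odd-indexed bytes,
    // across the concatenation of two 16-byte registers
    static void uzp_s8(const int8_t lo[16], const int8_t hi[16],
                       int8_t even[16], int8_t odd[16]) {
        int8_t cat[32];
        for (int j = 0; j < 16; ++j) { cat[j] = lo[j]; cat[16 + j] = hi[j]; }
        for (int j = 0; j < 16; ++j) {
            even[j] = cat[2*j];     // vuzp1q_s8
            odd[j]  = cat[2*j + 1]; // vuzp2q_s8
        }
    }

    int main(void) {
        // y holds the 32 q8_0 quants of a block in sequential element order
        int8_t y[32], even[16], odd[16];
        for (int j = 0; j < 32; ++j) y[j] = (int8_t) j;

        uzp_s8(y, y + 16, even, odd);

        // even[j] = element 2j pairs with the low nibble of x byte j;
        // odd[j]  = element 2j+1 pairs with the high nibble
        for (int j = 0; j < 16; ++j) printf("(%d,%d) ", even[j], odd[j]);
        printf("\n");
        return 0;
    }
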
@@ -2587,14 +2663,102 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     sumf = sum0 + sum1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytesFromNibbles(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert the vector of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q by the scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float d1 = y[i].d;
 
         const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const int8_t  * restrict p1 = y[i].qs;
 
         int sumi = 0;
         for (int j = 0; j < QK/2; j++) {
@@ -2603,10 +2767,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
             const int i0 = (int8_t) (v0 & 0xf) - 8;
             const int i1 = (int8_t) (v0 >> 4) - 8;
 
-            const int i2 = (int) p1[2*j + 0] - 128;
-            const int i3 = (int) p1[2*j + 1] - 128;
-
-            /*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
 
             sumi += i0*i2 + i1*i3;
         }
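Note: the AVX2/AVX kernels lean on a standard trick. _mm256_maddubs_epi16 multiplies unsigned bytes by signed bytes, but bx is signed after the -8 offset, so the code feeds it ax = |bx| and moves bx's sign onto by with _mm256_sign_epi8. Since |x| * (sign(x) * y) == x * y elementwise (and sign_epi8 zeroes y wherever x is zero), the products are unchanged; with |x| <= 8 and |y| <= 127 each pairwise int16 sum stays well below the saturation limit of maddubs, and _mm256_madd_epi16 against ones then widens the pairs to int32. A scalar check of the identity over the full q4_0 x q8_0 range (a sketch of the reasoning, not the intrinsics):

    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    // Scalar model of sign_epi8(a, b): a with the sign of b, 0 when b == 0
    static int sign_model(int a, int b) {
        return b > 0 ? a : (b < 0 ? -a : 0);
    }

    int main(void) {
        for (int x = -8; x <= 7; ++x) {          // q4_0 value range after -8 offset
            for (int y = -127; y <= 127; ++y) {  // q8_0 value range
                const int ax = abs(x);           // sign_epi8(x, x)
                const int sy = sign_model(y, x); // sign_epi8(y, x)
                if (ax * sy != x * y) {
                    printf("mismatch at x=%d y=%d\n", x, y);
                    return 1;
                }
            }
        }
        printf("ax*sy == x*y for all q4_0 x and q8_0 y\n");
        return 0;
    }
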
@@ -10171,7 +10333,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                 } else
 #endif
-                cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                {
+                    cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                }
             } else {
                 GGML_ASSERT(false);
             }
