-
Notifications
You must be signed in to change notification settings - Fork 0
Faster q3_0 implementation, using two planes #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
f54da1f
21acb58
394a3c1
f7874df
b304c4d
e893d33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -606,12 +606,12 @@ typedef struct { | |
static_assert(sizeof(block_q2_0) == sizeof(ggml_fp16_t) + QK2_0 / 4, "wrong q2_0 size/padding"); | ||
|
||
#define QK3_0 16 | ||
typedef union { | ||
struct { | ||
uint16_t pad[3]; | ||
ggml_fp16_t d; | ||
}; | ||
uint64_t qs; | ||
typedef struct { | ||
ggml_fp16_t d; | ||
// Instead of representing q3_0 as a packed format "...210210210210", | ||
// represent it as two planes: "...10101010" and "...2222" | ||
uint16_t qhi; // The highest bit of each 3-bit number, packed together | ||
uint32_t qlo; // The low 2-bits of each 3-bit number, packed together | ||
} block_q3_0; | ||
static_assert(sizeof(block_q3_0) == sizeof(ggml_fp16_t) + QK3_0 * 3 / 8, "wrong q3_0 size/padding"); | ||
|
||
|
@@ -691,17 +691,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y, | |
const float d = max / -4; | ||
const float id = d ? 1.0f/d : 0.0f; | ||
|
||
uint64_t qs = 0; | ||
uint32_t lo = 0; | ||
uint16_t hi = 0; | ||
|
||
for (int l = 0; l < QK3_0; l++) { | ||
const float v = x[i*QK3_0 + l]*id; | ||
for (int l = 0; l < 16; l++) { | ||
const float v = x[i*16 + l]*id; | ||
const uint8_t vi = MIN(7, (int8_t)roundf(v) + 4); | ||
assert(vi < 8); | ||
qs |= (uint64_t)vi << (l*3); | ||
lo |= (vi & 3) << (l * 2); | ||
hi |= ((vi >> 2) & 1) << l; | ||
} | ||
|
||
y[i].qs = qs; | ||
y[i].d = GGML_FP32_TO_FP16(d); // overwrite unused part of uint64_t qs | ||
y[i].d = GGML_FP32_TO_FP16(d); | ||
y[i].qlo = lo; | ||
y[i].qhi = hi; | ||
} | ||
} | ||
|
||
|
@@ -1335,13 +1338,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in | |
|
||
for (int i = 0; i < nb; i++) { | ||
const float d = GGML_FP16_TO_FP32(x[i].d); | ||
uint64_t qs = x[i].qs; | ||
for (int l = 0; l < QK3_0; l++) { | ||
const int8_t vi = qs & 7; | ||
uint_fast32_t lo = x[i].qlo; | ||
uint_fast32_t hi = x[i].qhi << 2; | ||
for (int l = 0; l < 16; l++) { | ||
const int8_t vi = (lo & 3) | (hi & 4); | ||
const float v = (vi - 4)*d; | ||
y[i*QK3_0 + l] = v; | ||
assert(!isnan(y[i*QK3_0 + l])); | ||
qs >>= 3; | ||
y[i*16 + l] = v; | ||
assert(!isnan(y[i*16 + l])); | ||
lo >>= 2; | ||
hi >>= 1; | ||
} | ||
} | ||
} | ||
|
@@ -2193,6 +2198,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t | |
*s = sumf; | ||
} | ||
|
||
#if defined(__AVX2__) || defined(__AVX512F__)
/*
 * dotMul - grouped dot product of two vectors of 32 signed 8-bit integers.
 *
 *   bx, by   32 packed int8 values each.
 *   returns  8 packed floats; lane g holds
 *            bx[4g]*by[4g] + bx[4g+1]*by[4g+1] + bx[4g+2]*by[4g+2] + bx[4g+3]*by[4g+3].
 */
static inline __m256 dotMul(__m256i bx, __m256i by) {
#  if defined(__AVXVNNIINT8__)
    // _mm256_dpbssd_epi32(src, a, b) returns src + dot4(a, b): the 32-bit
    // accumulator is the FIRST operand and the two int8 vectors follow it.
    // Passing the zero vector last would multiply by zero and return bx's raw bits.
    const __m256i i32 = _mm256_dpbssd_epi32(_mm256_setzero_si256(), bx, by);
#  else
    // _mm256_maddubs_epi16 multiplies UNSIGNED bytes (first operand) by SIGNED
    // bytes (second operand), so move the sign of bx onto by first:
    //   |bx| * (by * sign(bx)) == bx * by
    // Get absolute values of x vectors
    const __m256i ax = _mm256_sign_epi8(bx, bx);
    // Sign the values of the y vectors
    const __m256i sy = _mm256_sign_epi8(by, bx);
    // Perform multiplication and create 16-bit values (adjacent pairs are
    // summed with signed saturation)
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);

    // Convert int16_t to int32_t by adding pairwise
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i i32 = _mm256_madd_epi16(ones, dot);
#  endif
    // Convert int32_t to float
    return _mm256_cvtepi32_ps(i32);
}

/*
 * horizontalSum - add up the 8 floats packed in acc and return the scalar total.
 */
static inline float horizontalSum(__m256 acc) {
    // Fold the upper 128-bit half onto the lower half: 8 partial sums -> 4
    __m128 res = _mm256_extractf128_ps(acc, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
    // Fold elements 2,3 onto elements 0,1: 4 -> 2
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    // Add element 1 to element 0: 2 -> 1
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}
#endif
|
||
static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { | ||
assert(n % QK2_0 == 0); | ||
const int nb = n / QK2_0; | ||
|
@@ -2222,30 +2260,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * | |
// Load y vector | ||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); | ||
|
||
// Get absolute values of x vectors | ||
const __m256i ax = _mm256_sign_epi8(bx, bx); | ||
// Sign the values of the y vectors | ||
const __m256i sy = _mm256_sign_epi8(by, bx); | ||
// Perform multiplication and create 16-bit values | ||
const __m256i dot = _mm256_maddubs_epi16(ax, sy); | ||
|
||
// Convert int16_t to int32_t by adding pairwise | ||
const __m256i ones = _mm256_set1_epi16(1); | ||
__m256i i32 = _mm256_madd_epi16(ones, dot); | ||
|
||
// Convert int32_t to float | ||
__m256 p = _mm256_cvtepi32_ps(i32); | ||
// Do the product: | ||
__m256 p = dotMul(bx, by); | ||
|
||
// Apply the scale, and accumulate | ||
acc = _mm256_fmadd_ps(scale, p, acc); | ||
} | ||
|
||
// Return horizontal sum of the acc vector | ||
__m128 res = _mm256_extractf128_ps(acc, 1); | ||
res = _mm_add_ps(res, _mm256_castps256_ps128(acc)); | ||
res = _mm_add_ps(res, _mm_movehl_ps(res, res)); | ||
res = _mm_add_ss(res, _mm_movehdup_ps(res)); | ||
sumf = _mm_cvtss_f32(res); | ||
sumf = horizontalSum(acc); | ||
#else | ||
for (int i = 0; i < nb; i++) { | ||
const float d0 = GGML_FP16_TO_FP32(x[i].d); | ||
|
@@ -2270,6 +2293,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * | |
*s = sumf; | ||
} | ||
|
||
// Lookup table used to convert q3_0 to SIMD vectors. | ||
// Expands the bits of an 8-bit value into a 64 bit result, turning each bit into a byte. | ||
// A zero bit turns into 0xFC, while a one bit turns into 0x00. | ||
// Bit k of the table index selects byte k of the entry (bit 0 -> least significant byte), | ||
// e.g. index 0x01 -> 0xFCFCFCFCFCFCFC00 and index 0x80 -> 0x00FCFCFCFCFCFCFC. | ||
// ggml_vec_dot_q3_0_q8_0 ORs these bytes onto the low two bits of each value: 0xFC | lo is the | ||
// two's-complement byte for lo - 4 and 0x00 | lo is lo, i.e. the stored 3-bit value minus 4. | ||
// NOTE(review): this relies on bytesFromCrumbs (not shown here) yielding one 0..3 value per byte - confirm. | ||
// | ||
// B0 pastes the "0x" prefix onto the hex digits collected so far. Each Bk (k = 1..8) emits the | ||
// two halves of its sub-table, appending "FC" (index bit clear) and then "00" (index bit set). | ||
// B8 supplies the most significant byte (index bit 7); B1 the least significant (index bit 0). | ||
#define B0(n) 0x ## n | ||
#define B1(n) B0(n ## FC), B0(n ## 00) | ||
#define B2(n) B1(n ## FC), B1(n ## 00) | ||
#define B3(n) B2(n ## FC), B2(n ## 00) | ||
#define B4(n) B3(n ## FC), B3(n ## 00) | ||
#define B5(n) B4(n ## FC), B4(n ## 00) | ||
#define B6(n) B5(n ## FC), B5(n ## 00) | ||
#define B7(n) B6(n ## FC), B6(n ## 00) | ||
#define B8( ) B7( FC), B7( 00) | ||
// 256 entries, one per possible byte of qhi (indexed with qhi & 0xFF and qhi >> 8). | ||
static const uint64_t ggml_q3_table[256] = { B8() }; | ||
|
||
static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { | ||
assert(n % QK3_0 == 0); | ||
const int nb = n / QK3_0; | ||
|
@@ -2282,103 +2319,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * | |
|
||
#if defined(__AVX2__) | ||
// Initialize accumulator with zeros | ||
__m128 acc = _mm_setzero_ps(); | ||
__m256 acc = _mm256_setzero_ps(); | ||
|
||
for (int i = 0; i < nb/2; i++) { | ||
const __m128 scale_y = _mm_set1_ps(y[i].d); | ||
for (int u = 0; u < 2; u++) { // let the compiler unroll this | ||
// Compute combined scale for the block | ||
const __m128 scale_x = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+u].d)); | ||
const __m128 scale = _mm_mul_ps(scale_x, scale_y); | ||
|
||
__m256i bxx = _mm256_set1_epi64x(x[i*2+u].qs); | ||
|
||
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale | ||
|
||
// shift the copies to be able to reach all values | ||
// 255 192 128 64 0 | ||
// | | | | | ||
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in | ||
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left | ||
// _______________________sssssfedcba98765432__________________________________________ shift right | ||
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out | ||
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ | ||
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0 | ||
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0); | ||
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64); | ||
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r)); | ||
|
||
// add to itself in masked places to shift some values left one bit | ||
// 127 64 0 | ||
// | | | | | | | | | | | | | | | | | ||
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in | ||
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask | ||
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked | ||
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum | ||
// | ||
// 255 192 128 | ||
// | | | | | | | | | | | | | | | | | ||
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in | ||
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask | ||
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked | ||
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum | ||
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000); | ||
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx)); | ||
|
||
// collect 16 bytes from 256 into 128 bits | ||
const __m256i shufmask = _mm256_set_epi8( | ||
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1, | ||
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0); | ||
bxx = _mm256_shuffle_epi8(bxx, shufmask); | ||
|
||
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1)); | ||
|
||
const __m128i mask = _mm_set1_epi8(7); | ||
bx = _mm_and_si128(mask, bx); | ||
|
||
const __m128i off = _mm_set1_epi8(4); | ||
bx = _mm_sub_epi8(bx, off); | ||
|
||
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + u*QK3_0)); | ||
__m256i bx = bytesFromCrumbs(x[i*2+1].qlo, x[i*2].qlo); | ||
|
||
// Get absolute values of x vectors | ||
const __m128i ax = _mm_sign_epi8(bx, bx); | ||
// Sign the values of the y vectors | ||
const __m128i sy = _mm_sign_epi8(by, bx); | ||
// Perform multiplication and create 16-bit values | ||
const __m128i dot = _mm_maddubs_epi16(ax, sy); | ||
__m256i const bxhi = _mm256_set_epi64x( | ||
ggml_q3_table[x[i*2+1].qhi >> 8], ggml_q3_table[x[i*2+1].qhi & 0xFF], | ||
ggml_q3_table[x[i*2+0].qhi >> 8], ggml_q3_table[x[i*2+0].qhi & 0xFF]); | ||
|
||
// Convert int16_t to int32_t by adding pairwise | ||
const __m128i ones = _mm_set1_epi16(1); | ||
__m128i i32 = _mm_madd_epi16(dot, ones); | ||
// OR the high bits (which also handles the sign): | ||
bx = _mm256_or_si256(bx, bxhi); | ||
|
||
// Convert int32_t to float | ||
const __m128 p = _mm_cvtepi32_ps(i32); | ||
// Compute combined scale for the block | ||
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+0].d)); | ||
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+1].d)); | ||
__m256 scale = _mm256_set_m128(scale_hi, scale_lo); | ||
scale = _mm256_mul_ps(scale, _mm256_broadcast_ss(&y[i].d)); | ||
|
||
// Apply the scale, and accumulate | ||
acc = _mm_fmadd_ps(scale, p, acc); | ||
} | ||
// Load y vector | ||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); | ||
|
||
// Do the product, | ||
__m256 p = dotMul(bx, by); | ||
|
||
// Apply the scale, and accumulate | ||
acc = _mm256_fmadd_ps(scale, p, acc); | ||
} | ||
|
||
// Return horizontal sum of the acc vector | ||
__m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc)); | ||
res = _mm_add_ss(res, _mm_movehdup_ps(res)); | ||
sumf = _mm_cvtss_f32(res); | ||
sumf = horizontalSum(acc); | ||
#else | ||
for (int i = 0; i < nb; i++) { | ||
const float d0 = GGML_FP16_TO_FP32(x[i].d); | ||
const float d1 = y[i/2].d; | ||
|
||
uint64_t qs0 = x[i].qs; | ||
uint_fast32_t lo0 = x[i].qlo; | ||
uint_fast32_t hi0 = x[i].qhi << 2; | ||
const int8_t * restrict p1 = y[i/2].qs + (i%2)*QK3_0; | ||
|
||
int sumi = 0; | ||
for (int j = 0; j < QK3_0; j++) { | ||
const int8_t i0 = (int8_t)(qs0 & 7) - 4; | ||
const int_fast16_t i1 = p1[j]; | ||
for (int l = 0; l < 16; l++) { | ||
const int8_t i0 = (int8_t)((lo0 & 3) | ((hi0 & 4) - 4)); | ||
const int_fast16_t i1 = p1[l]; | ||
|
||
sumi += i0 * i1; | ||
|
||
qs0 >>= 3; | ||
lo0 >>= 2; | ||
hi0 >>= 1; | ||
} | ||
sumf += d0 * d1 * sumi; | ||
} | ||
|
@@ -11622,19 +11610,20 @@ size_t ggml_quantize_q2_0(const float * src, void * dst, int n, int k, int64_t h | |
|
||
size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t hist[1<<3]) { | ||
assert(k % QK3_0 == 0); | ||
const int nb = k / QK3_0; | ||
|
||
for (int j = 0; j < n; j += k) { | ||
block_q3_0 * restrict y = (block_q3_0 *)dst + j/QK3_0; | ||
|
||
quantize_row_q3_0(src + j, y, k); | ||
|
||
for (int i = 0; i < nb; i++) { | ||
uint64_t qs = y[i].qs; | ||
for (int l = 0; l < QK3_0; l++) { | ||
const int8_t vi = qs & 7; | ||
for (int i = 0; i < 16; i++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @pubby : I'm confused by this... why should we have two nested loops that go up to 16? It's getting late for me, possibly you're right about this. But then q2 would also be wrong? I'll look into it tomorrow. Also, you changed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that line's a mistake. I believe it should be The
The 16 number referred to the bit count per I scrapped that version though, so I'll revert the changes and use |
||
uint_fast32_t lo = y[i].qlo; | ||
uint_fast32_t hi = y[i].qhi << 2; | ||
for (int l = 0; l < 16; l++) { | ||
int8_t vi = (lo & 3) | (hi & 4); | ||
hist[vi]++; | ||
qs >>= 3; | ||
lo >>= 2; | ||
hi >>= 1; | ||
} | ||
} | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.