Skip to content

Add partial AVX512 Linux support for dot product on 4-bit quantized values #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,38 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo)
ifneq (,$(findstring avx512f,$(AVX512F_M)))
CFLAGS += -mavx512f
endif
AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo)
ifneq (,$(findstring avx512bw,$(AVX512BW_M)))
CFLAGS += -mavx512bw
endif
AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo)
ifneq (,$(findstring avx512dq,$(AVX512DQ_M)))
CFLAGS += -mavx512dq
endif
AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo)
ifneq (,$(findstring avx512vl,$(AVX512VL_M)))
CFLAGS += -mavx512vl
endif
AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo)
ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
CFLAGS += -mavx512cd
endif
AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo)
ifneq (,$(findstring avx512er,$(AVX512ER_M)))
CFLAGS += -mavx512er
endif
AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo)
ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M)))
CFLAGS += -mavx512ifma
endif
AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo)
ifneq (,$(findstring avx512pf,$(AVX512PF_M)))
CFLAGS += -mavx512pf
endif
else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
Expand Down
92 changes: 89 additions & 3 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1412,10 +1412,96 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
#error "not implemented for QK"
#endif
#elif defined(__AVX2__)
#elif defined(__AVX512F__)

inline __m256i bytesFromNibbles( const uint8_t* rsi ){
// Load 16 bytes from memory
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );

// Expand bytes into uint16_t values
__m256i bytes = _mm256_cvtepu8_epi16( tmp );

// Unpack values into individual bytes
const __m256i lowMask = _mm256_set1_epi8( 0xF );
__m256i high = _mm256_andnot_si256( lowMask, bytes );
__m256i low = _mm256_and_si256( lowMask, bytes );
high = _mm256_slli_epi16( high, 4 );
bytes = _mm256_or_si256( low, high );
return bytes;
}

inline __m512 process_one_block(
__m512 acc,
const uint8_t * pd0,
const uint8_t * pd1,
const uint8_t * pb0,
const uint8_t * pb1,
int i
) {
const float * d0_0 = (const float *) (pd0 + i*bs);
const float * d1_0 = (const float *) (pd1 + i*bs);

const uint8_t * restrict p0 = pb0 + (i+0)*bs;
const uint8_t * restrict p1 = pb1 + (i+0)*bs;

// Compute combined scale for the block
float scaleScalar = d0_0[0] * d1_0[0];
__m512 scale = _mm512_set1_ps( scaleScalar );

__m256i bx = bytesFromNibbles( p0 );
__m256i by = bytesFromNibbles( p1 );

// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );
by = _mm256_sub_epi8( by, off );

// Sign-extend 16 signed bytes into int16_t
__m512i x32 = _mm512_cvtepi8_epi16( bx );
__m512i y32 = _mm512_cvtepi8_epi16( by );
// Compute products of int16_t integers, add pairwise
__m512i i64 = _mm512_madd_epi16( x32, y32 );

// Convert int32_t to float
__m512 p = _mm512_cvtepi32_ps( i64 );
// Apply the scale, and accumulate
return _mm512_fmadd_ps( scale, p, acc );
}

#if QK == 32
const size_t countBlocks = nb;
// Initialize accumulator with zeros
__m512 acc0 = _mm512_setzero_ps();
__m512 acc1 = _mm512_setzero_ps();

const int superblock_size = 8;
const int superblock_count = nb / superblock_size;
const int remainder = nb % superblock_size;

for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
int i = superblock_ix * superblock_size;

acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+0 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+1 );
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+2 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+3 );
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+4 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+5 );
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+6 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+7 );
}

// Remainders
for (int i = superblock_count * superblock_size; i < nb; ++i) {
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i );
}

// Horizontal sum of all lanes of the accumulator
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
#else
#error "not implemented for QK"
#endif
#elif defined(__AVX2__)
#if QK == 32
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

Expand Down Expand Up @@ -5485,7 +5571,7 @@ static void ggml_compute_forward_rms_norm_f32(
mean /= ne00;

float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
Expand Down