Skip to content

Commit c5e3b52

Browse files
committed
ggml: dynamic x86_64 feature detection for FP32 <-> FP16/BF16 conversion
1 parent 13be08d commit c5e3b52

File tree

1 file changed

+177
-46
lines changed

1 file changed

+177
-46
lines changed

ggml/src/ggml.c

Lines changed: 177 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
#include <TargetConditionals.h>
4343
#endif
4444

45+
#if defined(__x86_64__)
46+
#include <immintrin.h>
47+
#if defined(_MSC_VER)
48+
# include <intrin.h>
49+
#endif
50+
#endif
51+
4552
#if defined(_WIN32)
4653
#define WIN32_LEAN_AND_MEAN
4754
#ifndef NOMINMAX
@@ -382,62 +389,186 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
382389
}
383390
}
384391

385-
// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
386-
// currently, the ggml_cpu_has_* functions are entirely compile-time
387-
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
388-
int64_t i = 0;
389-
#if defined(__F16C__)
390-
//if (ggml_cpu_has_f16c()) {
391-
for (; i + 7 < n; i += 8) {
392-
__m256 x_vec = _mm256_loadu_ps(x + i);
393-
__m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
394-
_mm_storeu_si128((__m128i *)(y + i), y_vec);
395-
}
396-
for(; i + 3 < n; i += 4) {
397-
__m128 x_vec = _mm_loadu_ps(x + i);
398-
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
399-
_mm_storel_epi64((__m128i *)(y + i), y_vec);
400-
}
401-
//}
392+
#if defined(__x86_64__)
393+
394+
#if defined(_MSC_VER)
#include <intrin.h>
// Execute CPUID for the given leaf/subleaf and return EAX..EDX through the
// out-parameters. MSVC build: use the __cpuidex compiler intrinsic.
// NOTE(review): this branch looks unreachable as written — MSVC defines
// _M_X64, not __x86_64__, and the enclosing region is guarded by
// #if defined(__x86_64__). Confirm whether clang-cl is the intended target.
static void cpuid(int leaf, int subleaf, int *eax, int *ebx, int *ecx, int *edx) {
    int regs[4];
    __cpuidex(regs, leaf, subleaf);
    *eax = regs[0];
    *ebx = regs[1];
    *ecx = regs[2];
    *edx = regs[3];
}
#elif defined(__GNUC__) || defined(__clang__)
// GCC/clang build: issue CPUID via inline assembly. EAX selects the leaf,
// ECX the subleaf; all four result registers are written back.
static void cpuid(int leaf, int subleaf, int *eax, int *ebx, int *ecx, int *edx) {
    __asm__ volatile (
        "cpuid"
        : "=a"(*eax), "=b"(*ebx), "=c"(*ecx), "=d"(*edx)
        : "a"(leaf), "c"(subleaf)
    );
}
#else
#error Unsupported compiler
#endif
403-
for (; i < n; i++) {
415+
416+
// Runtime check for F16C (VEX-encoded FP16<->FP32 conversion instructions).
// CPUID.1:ECX bit 29 reports CPU support. In addition, the OS must have
// enabled XMM/YMM state saving (CPUID.1:ECX.OSXSAVE bit 27, and XCR0
// bits 1..2 via XGETBV) — otherwise executing VEX-encoded instructions
// faults even on a CPU that supports them.
static bool x86_64_supports_f16c(void) {
    int eax, ebx, ecx, edx;
    cpuid(1, 0, &eax, &ebx, &ecx, &edx);
    if ((ecx & (1 << 29)) == 0) {
        return false; // CPU lacks F16C
    }
    if ((ecx & (1 << 27)) == 0) {
        return false; // OS has not enabled XSAVE (no OSXSAVE)
    }
    unsigned int xcr0_lo, xcr0_hi;
#if defined(_MSC_VER)
    unsigned long long xcr0 = _xgetbv(0);
    xcr0_lo = (unsigned int) xcr0;
    xcr0_hi = (unsigned int)(xcr0 >> 32);
#else
    __asm__ volatile ("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0));
#endif
    (void) xcr0_hi;
    // XCR0 bit 1 = XMM state, bit 2 = YMM state: both must be OS-enabled
    return (xcr0_lo & 0x6) == 0x6;
}
421+
422+
// Runtime check for AVX2. CPUID.7.0:EBX bit 5 reports CPU support, but the
// OS must also have enabled XMM/YMM state (OSXSAVE + XCR0 bits 1..2) or
// AVX2 instructions fault at runtime.
static bool x86_64_supports_avx2(void) {
    int eax, ebx, ecx, edx;
    cpuid(0, 0, &eax, &ebx, &ecx, &edx);
    if (eax < 7) {
        return false; // CPUID leaf 7 not available
    }
    cpuid(7, 0, &eax, &ebx, &ecx, &edx);
    if ((ebx & (1 << 5)) == 0) {
        return false; // CPU lacks AVX2
    }
    cpuid(1, 0, &eax, &ebx, &ecx, &edx);
    if ((ecx & (1 << 27)) == 0) {
        return false; // OS has not enabled XSAVE (no OSXSAVE)
    }
    unsigned int xcr0_lo, xcr0_hi;
#if defined(_MSC_VER)
    unsigned long long xcr0 = _xgetbv(0);
    xcr0_lo = (unsigned int) xcr0;
    xcr0_hi = (unsigned int)(xcr0 >> 32);
#else
    __asm__ volatile ("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0));
#endif
    (void) xcr0_hi;
    // XCR0 bit 1 = XMM state, bit 2 = YMM state: both must be OS-enabled
    return (xcr0_lo & 0x6) == 0x6;
}
430+
431+
// Runtime check for AVX-512 Foundation. CPUID.7.0:EBX bit 16 reports CPU
// support; the OS must additionally have enabled XMM/YMM state (XCR0
// bits 1..2) and the opmask/ZMM state components (XCR0 bits 5..7), or
// AVX-512 instructions fault at runtime.
static bool x86_64_supports_avx512f(void) {
    int eax, ebx, ecx, edx;
    cpuid(0, 0, &eax, &ebx, &ecx, &edx);
    if (eax < 7) {
        return false; // CPUID leaf 7 not available
    }
    cpuid(7, 0, &eax, &ebx, &ecx, &edx);
    if ((ebx & (1 << 16)) == 0) {
        return false; // CPU lacks AVX512F
    }
    cpuid(1, 0, &eax, &ebx, &ecx, &edx);
    if ((ecx & (1 << 27)) == 0) {
        return false; // OS has not enabled XSAVE (no OSXSAVE)
    }
    unsigned int xcr0_lo, xcr0_hi;
#if defined(_MSC_VER)
    unsigned long long xcr0 = _xgetbv(0);
    xcr0_lo = (unsigned int) xcr0;
    xcr0_hi = (unsigned int)(xcr0 >> 32);
#else
    __asm__ volatile ("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0));
#endif
    (void) xcr0_hi;
    // bits 1..2: XMM/YMM; bits 5..7: opmask, ZMM0-15 upper halves, ZMM16-31
    return (xcr0_lo & 0xe6) == 0xe6;
}
438+
439+
static struct ggml_type_traits type_traits[GGML_TYPE_COUNT];
440+
441+
// Scalar fallback: convert n FP32 values to FP16, one element at a time.
// Used directly on CPUs without F16C and as the tail handler of the
// vectorized kernels.
static inline void ggml_fp32_to_fp16_generic(const float * x, ggml_fp16_t * y, int64_t n) {
    const float * src_end = x + n;
    while (x < src_end) {
        *y++ = GGML_FP32_TO_FP16(*x++);
    }
}
407446

408-
void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
447+
// Vectorized FP32 -> FP16 row conversion using F16C (VCVTPS2PH).
// Built with the "f16c" target attribute so the translation unit itself can
// be compiled without -mf16c; must only be invoked after a runtime F16C
// check (see ggml_fp32_to_fp16_row).
static inline void __attribute__((target("f16c"))) ggml_fp32_to_fp16_row_f16c(const float * x, ggml_fp16_t * y, int64_t n) {
    int64_t i = 0;
    // 8 floats per iteration: 256-bit load, convert, 128-bit store (8 x fp16)
    for (; i + 7 < n; i += 8) {
        __m256 x_vec = _mm256_loadu_ps(x + i);
        __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
        _mm_storeu_si128((__m128i *)(y + i), y_vec);
    }
    // 4 floats per iteration: 128-bit load, convert, 64-bit store (4 x fp16)
    for (; i + 3 < n; i += 4) {
        __m128 x_vec = _mm_loadu_ps(x + i);
        __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
        _mm_storel_epi64((__m128i *)(y + i), y_vec);
    }
    // remaining 0..3 elements: scalar fallback
    ggml_fp32_to_fp16_generic(x + i, y + i, n - i);
}
461+
462+
// Vectorized FP32 -> FP16 row conversion using AVX-512F (512-bit VCVTPS2PH).
// Built with the "avx512f" target attribute so the TU can be compiled
// without -mavx512f. The tail (0..15 elements) delegates to the F16C
// kernel, so this variant must only be selected when F16C is also usable —
// the dispatcher requires has_avx512f && has_f16c for exactly that reason.
static inline void __attribute__((target("avx512f"))) ggml_fp32_to_fp16_row_avx512f(const float * x, ggml_fp16_t * y, int64_t n) {
    int64_t i = 0;
    // 16 floats per iteration: 512-bit load, convert, 256-bit store (16 x fp16)
    for (; i + 15 < n; i += 16) {
        __m512 x_vec = _mm512_loadu_ps(x + i);
        __m256i y_vec = _mm512_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
        _mm256_storeu_si256((__m256i *)(y + i), y_vec);
    }
    ggml_fp32_to_fp16_row_f16c(x + i, y + i, n - i);
}
471+
472+
// FP32 -> FP16 row conversion with one-time runtime CPU dispatch.
// First call probes CPUID and caches the selected kernel in a function-local
// static; subsequent calls jump straight to the cached pointer.
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
    // NOTE(review): this lazy init is not thread-safe — concurrent first
    // callers both probe CPUID and both store the static pointer, which is
    // a data race in C11 terms (benign in practice only because every
    // writer stores the same value). Consider C11 atomics or call_once.
    static ggml_from_float_t from_float_ref = NULL;
    if (from_float_ref != NULL) {
        from_float_ref(x, y, n);
        return;
    }

    bool has_avx512f = x86_64_supports_avx512f();
    bool has_f16c = x86_64_supports_f16c();
    // AVX-512 kernel tail-calls the F16C kernel, hence the combined check.
    if (has_avx512f && has_f16c) {
        // use AVX512F
        from_float_ref = (ggml_from_float_t)ggml_fp32_to_fp16_row_avx512f;
    } else if (has_f16c) {
        // use F16C
        from_float_ref = (ggml_from_float_t)ggml_fp32_to_fp16_row_f16c;
    } else {
        // fallback to generic implementation
        from_float_ref = (ggml_from_float_t)ggml_fp32_to_fp16_generic;
    }
    // Publish the selected kernel into the global type-traits table as well
    // (this is why type_traits had to lose its const qualifier).
    type_traits[GGML_TYPE_F16].from_float_ref = from_float_ref;
    from_float_ref(x, y, n);
}
494+
495+
#else
496+
// Non-x86_64 build: plain scalar FP32 -> FP16 row conversion.
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
    const float * src = x;
    ggml_fp16_t * dst = y;
    for (int64_t left = n; left > 0; --left) {
        *dst++ = GGML_FP32_TO_FP16(*src++);
    }
}
501+
435502
#endif
436-
for (; i < n; i++) {
503+
504+
#if defined(__x86_64__)
505+
506+
507+
// Scalar fallback: widen n BF16 values to FP32, one element at a time.
// Used on CPUs without AVX2 and as the tail handler of the vector kernels.
static inline void ggml_bf16_to_fp32_generic(const ggml_bf16_t * x, float * y, int64_t n) {
    const ggml_bf16_t * src_end = x + n;
    while (x != src_end) {
        *y++ = GGML_BF16_TO_FP32(*x++);
    }
}
440512

513+
// Vectorized BF16 -> FP32 row conversion using AVX2.
// BF16 is the upper 16 bits of an IEEE-754 FP32 value, so widening is:
// zero-extend each 16-bit element to 32 bits, then shift left by 16.
// Built with the "avx2" target attribute so the TU compiles without -mavx2;
// must only be invoked after a runtime AVX2 check.
static inline void __attribute__((target("avx2"))) ggml_bf16_to_fp32_row_avx2(const ggml_bf16_t * x, float * y, int64_t n) {
    int64_t i = 0;
    // 8 elements per iteration: 128-bit load -> widen -> 256-bit store
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i,
                         _mm256_castsi256_ps(
                             _mm256_slli_epi32(
                                 _mm256_cvtepu16_epi32(
                                     _mm_loadu_si128(
                                         (const __m128i *)(x + i))),
                                 16)));
    }
    // remaining 0..7 elements: scalar fallback
    ggml_bf16_to_fp32_generic(x + i, y + i, n - i);
}
526+
527+
// Vectorized BF16 -> FP32 row conversion using AVX-512F (16 elements per
// iteration; same zero-extend + shift-left-16 trick as the AVX2 kernel).
// NOTE(review): the tail delegates to the AVX2 kernel, but the dispatcher
// selects this variant on has_avx512f alone without checking AVX2. Real
// AVX-512F CPUs do support AVX2, but that implication is not verified
// here — confirm or add an explicit has_avx2 check in the dispatcher.
static inline void __attribute__((target("avx512f"))) ggml_bf16_to_fp32_row_avx512f(const ggml_bf16_t * x, float * y, int64_t n) {
    int64_t i = 0;
    // 16 elements per iteration: 256-bit load -> widen -> 512-bit store
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i,
                         _mm512_castsi512_ps(
                             _mm512_slli_epi32(
                                 _mm512_cvtepu16_epi32(
                                     _mm256_loadu_si256(
                                         (const __m256i *)(x + i))),
                                 16)));
    }
    // remaining 0..15 elements handled by the narrower AVX2 kernel
    ggml_bf16_to_fp32_row_avx2(x + i, y + i, n - i);
}
540+
541+
// BF16 -> FP32 row conversion with one-time runtime CPU dispatch.
// First call probes CPUID and caches the selected kernel in a function-local
// static; subsequent calls jump straight to the cached pointer.
void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
    // NOTE(review): same non-thread-safe lazy init as ggml_fp32_to_fp16_row —
    // concurrent first callers race on this static (all store the same
    // value, but it is still a C11 data race). Consider atomics/call_once.
    static ggml_to_float_t to_float = NULL;
    if (to_float != NULL) {
        to_float(x, y, n);
        return;
    }
    bool has_avx512f = x86_64_supports_avx512f();
    bool has_avx2 = x86_64_supports_avx2();
    if (has_avx512f) {
        // use AVX512F
        // NOTE(review): the AVX-512 kernel's tail calls the AVX2 kernel, but
        // has_avx2 is not required here — verify the implication holds.
        to_float = (ggml_to_float_t)ggml_bf16_to_fp32_row_avx512f;
    } else if (has_avx2) {
        // use AVX2
        to_float = (ggml_to_float_t)ggml_bf16_to_fp32_row_avx2;
    } else {
        // fallback to generic implementation
        to_float = (ggml_to_float_t)ggml_bf16_to_fp32_generic;
    }
    // Publish the selected kernel into the global type-traits table as well.
    type_traits[GGML_TYPE_BF16].to_float = to_float;
    to_float(x, y, n);
}
562+
563+
#else
564+
565+
// Non-x86_64 build: plain scalar BF16 -> FP32 row conversion.
void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
    const ggml_bf16_t * src = x;
    float * dst = y;
    for (int64_t left = n; left > 0; --left) {
        *dst++ = GGML_BF16_TO_FP32(*src++);
    }
}
570+
#endif
571+
441572
void ggml_fp32_to_bf16_row_ref(const float * x, ggml_bf16_t * y, int64_t n) {
442573
for (int i = 0; i < n; i++) {
443574
y[i] = ggml_compute_fp32_to_bf16(x[i]);
@@ -569,7 +700,7 @@ static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const fl
569700
static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
570701
static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
571702

572-
static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
703+
static struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
573704
[GGML_TYPE_I8] = {
574705
.type_name = "i8",
575706
.blck_size = 1,

0 commit comments

Comments
 (0)