Skip to content

Commit e384297

Browse files
Srihari-mcwslaren
authored andcommitted
ggml : fixes for AVXVNNI instruction set with MSVC and Clang (ggml-org#11027)
* Fixes for clang AVX VNNI * enable AVX VNNI and alder lake build for MSVC * Apply suggestions from code review --------- Co-authored-by: slaren <[email protected]>
1 parent 7acec5c commit e384297

File tree

5 files changed

+15
-7
lines changed

5 files changed

+15
-7
lines changed

ggml/src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ if (GGML_CPU_ALL_VARIANTS)
290290
ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA)
291291
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
292292
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
293+
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
293294
if (NOT MSVC)
294-
# MSVC doesn't support AVX-VNNI or AMX
295-
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
295+
# MSVC doesn't support AMX
296296
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
297297
endif()
298298
else ()

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
215215
list(APPEND ARCH_DEFINITIONS GGML_SSE42)
216216
endif()
217217
if (GGML_AVX_VNNI)
218-
# MSVC generates AVX512 with AVX-VNNI intrinsics even with /arch:AVX2
219-
#list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
218+
list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
220219
endif()
221220
else ()
222221
if (GGML_NATIVE)

ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
194194
}
195195

196196
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
197-
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
197+
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
198198
const __m256i zero = _mm256_setzero_si256();
199199
return _mm256_dpbusd_epi32(zero, ax, sy);
200+
#elif defined(__AVXVNNI__)
201+
const __m256i zero = _mm256_setzero_si256();
202+
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
200203
#else
201204
// Perform multiplication and create 16-bit values
202205
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
103103
}
104104

105105
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
106-
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
106+
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
107107
const __m256i zero = _mm256_setzero_si256();
108108
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
109109
return _mm256_cvtepi32_ps(summed_pairs);
110+
#elif defined(__AVXVNNI__)
111+
const __m256i zero = _mm256_setzero_si256();
112+
const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
113+
return _mm256_cvtepi32_ps(summed_pairs);
110114
#else
111115
// Perform multiplication and create 16-bit values
112116
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,8 +1000,10 @@ class tinyBLAS_Q0_AVX {
10001000

10011001
inline __m256 updot(__m256i u, __m256i s) {
10021002
__m256i res;
1003-
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
1003+
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
10041004
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
1005+
#elif defined(__AVXVNNI__)
1006+
res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
10051007
#else
10061008
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
10071009
#endif

0 commit comments

Comments
 (0)