Skip to content

Commit b7034f1

Browse files
committed
Add detection code for avx
1 parent 3173a62 commit b7034f1

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

CMakeLists.txt

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,114 @@ option(LLAMA_OPENBLAS "llama: use OpenBLAS"
6868
# Build toggles for the test suite and the example programs. Both take their
# default value from LLAMA_STANDALONE, which is defined outside this excerpt.
option(
    LLAMA_BUILD_TESTS
    "llama: build tests"
    ${LLAMA_STANDALONE}
)
option(
    LLAMA_BUILD_EXAMPLES
    "llama: build examples"
    ${LLAMA_STANDALONE}
)
7070

71+
INCLUDE(CheckCSourceRuns)
72+
73+
# C probe program for AVX. CHECK_SSE("AVX" ...) looks this variable up by name
# (as <type>_CODE) and hands it to CHECK_C_SOURCE_RUNS. The program only has to
# compile and exit with 0; it touches a single 256-bit float intrinsic.
set(AVX_CODE "
#include <immintrin.h>
int main()
{
__m256 a;
a = _mm256_set1_ps(0);
return 0;
}
")
82+
83+
# C probe program for AVX-512. CHECK_SSE("AVX512" ...) looks this variable up
# by name (as <type>_CODE) and hands it to CHECK_C_SOURCE_RUNS. It builds a
# 512-bit integer vector from 64 byte lanes and does a byte-wise masked compare.
set(AVX512_CODE "
#include <immintrin.h>
int main()
{
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0);
__m512i b = a;
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
return 0;
}
")
100+
101+
# C probe program for AVX2. CHECK_SSE("AVX2" ...) looks this variable up by
# name (as <type>_CODE) and hands it to CHECK_C_SOURCE_RUNS. Besides a 256-bit
# integer intrinsic it also requires _mm256_extract_epi64 to be available.
set(AVX2_CODE "
#include <immintrin.h>
int main()
{
__m256i a = {0};
a = _mm256_abs_epi16(a);
__m256i x;
_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
return 0;
}
")
112+
113+
# C probe program for FMA. CHECK_SSE("FMA" ...) looks this variable up by name
# (as <type>_CODE) and hands it to CHECK_C_SOURCE_RUNS. It performs one fused
# multiply-add on 256-bit float vectors.
set(FMA_CODE "
#include <immintrin.h>
int main()
{
__m256 acc = _mm256_setzero_ps();
const __m256 d = _mm256_setzero_ps();
const __m256 p = _mm256_setzero_ps();
acc = _mm256_fmadd_ps( d, p, acc );
return 0;
}
")
124+
125+
# check_sse(<type> <flags>)
#
# Finds the first entry of <flags> with which the C program stored in
# <type>_CODE both compiles and runs successfully (via CHECK_C_SOURCE_RUNS).
#
#   type   Name of the instruction-set extension, e.g. "AVX". Selects the probe
#          source (<type>_CODE) and names the result variables.
#   flags  Semicolon-separated list of candidate compiler-flag strings, tried
#          in order. Empty entries are dropped by list expansion; the callers
#          in this file pass a single space as the first entry, presumably to
#          mean "no extra flags".
#
# Results (cache entries, marked as advanced):
#   <type>_FOUND  TRUE if some candidate worked, FALSE otherwise.
#   <type>_FLAGS  The candidate that worked, or "" if none did.
#
# If <type>_FOUND is already true (e.g. preset with -D<type>_FOUND=TRUE), no
# probing is done.
#
# This is a function rather than a macro, so CMAKE_REQUIRED_FLAGS and the
# loop bookkeeping stay local and the caller's values are never modified.
function(CHECK_SSE type flags)
    if(NOT ${type}_FOUND)
        if(CMAKE_CROSSCOMPILING AND NOT CMAKE_CROSSCOMPILING_EMULATOR)
            # A run check cannot execute target binaries here. try_run() would
            # end the configure step with a "please set the following cache
            # variables" error, so report "not found" instead of failing.
            message(STATUS
                "Cross-compiling without an emulator: skipping ${type} run check "
                "(preset ${type}_FOUND and ${type}_FLAGS to override)")
        else()
            # The attempt number keeps each probe's cached result variable
            # (HAS_<type>_<n>) distinct.
            set(attempt 1)
            foreach(candidate ${flags})
                set(CMAKE_REQUIRED_FLAGS "${candidate}")
                check_c_source_runs("${${type}_CODE}" HAS_${type}_${attempt})
                if(HAS_${type}_${attempt})
                    set(${type}_FOUND TRUE CACHE BOOL "${type} support")
                    set(${type}_FLAGS "${candidate}" CACHE STRING "${type} flags")
                    break()
                endif()
                math(EXPR attempt "${attempt}+1")
            endforeach()
        endif()
    endif()

    # Nothing worked (or probing was skipped): record an explicit negative
    # result so later code can rely on both cache entries existing.
    if(NOT ${type}_FOUND)
        set(${type}_FOUND FALSE CACHE BOOL "${type} support")
        set(${type}_FLAGS "" CACHE STRING "${type} flags")
    endif()

    mark_as_advanced(${type}_FOUND ${type}_FLAGS)
endfunction()
149+
150+
# Probe each extension. Every flag list is tried left to right and starts with
# a single-space entry (i.e. no real flag), followed by a GCC/Clang-style
# spelling and, except for FMA, an MSVC-style /arch: spelling.
check_sse("AVX" " ;-mavx;/arch:AVX")
check_sse("AVX2" " ;-mavx2 -mfma;/arch:AVX2")
check_sse("AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")
check_sse("FMA" " ;-mfma;")
154+
155+
# Mirror each detection result into the matching LLAMA_<ext> variable:
# ON when <ext>_FOUND evaluates true, OFF otherwise. These are plain
# (non-cache) variables, so they take precedence over a cache entry of the
# same name in this scope.
foreach(simd_ext AVX FMA AVX2 AVX512)
    # Expand the *value* of <ext>_FOUND into if(), as the original
    # IF(${AVX_FOUND}) form did, so the condition is evaluated identically.
    if(${${simd_ext}_FOUND})
        set(LLAMA_${simd_ext} ON)
    else()
        set(LLAMA_${simd_ext} OFF)
    endif()
endforeach()
178+
71179
#
72180
# Compile flags
73181
#

0 commit comments

Comments
 (0)