@@ -68,6 +68,114 @@ option(LLAMA_OPENBLAS "llama: use OpenBLAS"
68
68
# Whether to build the test suite; the default follows LLAMA_STANDALONE.
option(LLAMA_BUILD_TESTS
       "llama: build tests"
       ${LLAMA_STANDALONE})

# Whether to build the example programs; the default follows LLAMA_STANDALONE.
option(LLAMA_BUILD_EXAMPLES
       "llama: build examples"
       ${LLAMA_STANDALONE})
70
70
71
+ INCLUDE (CheckCSourceRuns )
72
+
73
# C probe program for AVX: declares a 256-bit float vector and fills it with
# _mm256_set1_ps. Looked up by name as "<type>_CODE" from CHECK_SSE.
set(AVX_CODE "
  #include <immintrin.h>
  int main()
  {
    __m256 a;
    a = _mm256_set1_ps(0);
    return 0;
  }
")
82
+
83
# C probe program for AVX-512: builds a 512-bit integer vector from 64 byte
# lanes with _mm512_set_epi8 and compares it to a copy of itself with
# _mm512_cmp_epi8_mask. Looked up by name as "<type>_CODE" from CHECK_SSE.
set(AVX512_CODE "
  #include <immintrin.h>
  int main()
  {
    __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0);
    __m512i b = a;
    __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
    return 0;
  }
")
100
+
101
# C probe program for AVX2: calls _mm256_abs_epi16 and _mm256_extract_epi64.
# Looked up by name as "<type>_CODE" from CHECK_SSE.
set(AVX2_CODE "
  #include <immintrin.h>
  int main()
  {
    __m256i a = {0};
    a = _mm256_abs_epi16(a);
    __m256i x;
    _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
    return 0;
  }
")
112
+
113
# C probe program for FMA: performs one fused multiply-add on zeroed 256-bit
# float vectors via _mm256_fmadd_ps. Looked up by name as "<type>_CODE" from
# CHECK_SSE.
set(FMA_CODE "
  #include <immintrin.h>
  int main()
  {
    __m256 acc = _mm256_setzero_ps();
    const __m256 d = _mm256_setzero_ps();
    const __m256 p = _mm256_setzero_ps();
    acc = _mm256_fmadd_ps( d, p, acc );
    return 0;
  }
")
124
+
125
# CHECK_SSE(<type> <flags>)
#
# Finds the first candidate compiler flag string with which the probe program
# stored in the variable <type>_CODE both compiles and runs.
#
# Arguments:
#   type  - feature name (e.g. "AVX"); the probe source is read from
#           ${type}_CODE.
#   flags - semicolon-separated list of candidate flag strings, tried in order.
#           Empty list elements are skipped by the unquoted expansion below.
#
# Results (cache entries, both marked as advanced):
#   <type>_FOUND - TRUE if a candidate worked, FALSE otherwise.
#   <type>_FLAGS - the candidate that worked, or an empty string.
# check_c_source_runs() additionally caches HAS_<type>_<n> for each candidate
# number <n> (1-based) that was actually tried.
#
# If <type>_FOUND is already true on entry, no probing is performed.
#
# Implemented as a function so the temporary CMAKE_REQUIRED_FLAGS override and
# the loop bookkeeping stay local and never leak into the caller's scope.
function(CHECK_SSE type flags)
  include(CheckCSourceRuns)

  if(NOT ${type}_FOUND)
    set(candidate_index 1)
    foreach(candidate ${flags})
      # Each candidate replaces (does not extend) the required flags for the
      # duration of the probe.
      set(CMAKE_REQUIRED_FLAGS "${candidate}")
      check_c_source_runs("${${type}_CODE}" HAS_${type}_${candidate_index})
      if(HAS_${type}_${candidate_index})
        set(${type}_FOUND TRUE CACHE BOOL "${type} support")
        set(${type}_FLAGS "${candidate}" CACHE STRING "${type} flags")
        # First working candidate wins; later ones are never tried.
        break()
      endif()
      math(EXPR candidate_index "${candidate_index}+1")
    endforeach()
  endif()

  # Nothing worked: record a negative result so callers can always read both
  # cache entries.
  if(NOT ${type}_FOUND)
    set(${type}_FOUND FALSE CACHE BOOL "${type} support")
    set(${type}_FLAGS "" CACHE STRING "${type} flags")
  endif()

  mark_as_advanced(${type}_FOUND ${type}_FLAGS)
endfunction()
149
+
150
# Probe each instruction-set extension. Every call passes the feature name
# and a ;-separated list of candidate flag strings; the first entry of each
# list is a single space, i.e. a probe with no feature-specific flag.
CHECK_SSE("AVX"
          " ;-mavx;/arch:AVX")
CHECK_SSE("AVX2"
          " ;-mavx2 -mfma;/arch:AVX2")
CHECK_SSE("AVX512"
          " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")
CHECK_SSE("FMA"
          " ;-mfma;")
154
+
155
# Mirror each detection result <FEATURE>_FOUND into the matching LLAMA_<FEATURE>
# switch: ON when the result is true, OFF otherwise.
# NOTE(review): these are normal (non-cache) variables; if LLAMA_AVX etc. are
# also declared with option() elsewhere, these assignments would take
# precedence over the user's cache value -- confirm that is intended.
foreach(simd_feature AVX FMA AVX2 AVX512)
  if(${${simd_feature}_FOUND})
    set(LLAMA_${simd_feature} ON)
  else()
    set(LLAMA_${simd_feature} OFF)
  endif()
endforeach()
178
+
71
179
#
72
180
# Compile flags
73
181
#
0 commit comments