@@ -45,16 +45,270 @@ extern "C" {
45
45
// 16-bit float
46
46
// on Arm, we use __fp16
47
47
// on x86, we use uint16_t
48
- #if defined(__ARM_NEON ) && !defined( _MSC_VER )
48
+ #if defined(__ARM_NEON )
49
49
50
50
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
51
51
//
52
52
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
53
53
//
54
54
#include <arm_neon.h>
55
55
56
+ #ifdef _MSC_VER
57
+
58
+ typedef uint16_t ggml_fp16_internal_t ;
59
+
60
+ #define ggml_vld1q_u32 (w ,x ,y ,z ) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
61
+
62
+ #else
63
+
56
64
typedef __fp16 ggml_fp16_internal_t ;
57
65
66
+ #define ggml_vld1q_u32 (w ,x ,y ,z ) { (w), (x), (y), (z) }
67
+
68
+ #endif // _MSC_VER
69
+
70
+ #if !defined(__aarch64__ )
71
+
72
+ // 32-bit ARM compatibility
73
+
74
+ // vaddvq_s16
75
+ // vpaddq_s16
76
+ // vpaddq_s32
77
+ // vaddvq_s32
78
+ // vaddvq_f32
79
+ // vmaxvq_f32
80
+ // vcvtnq_s32_f32
81
+ // vzip1_u8
82
+ // vzip2_u8
83
+
84
+ inline static int32_t vaddvq_s16 (int16x8_t v ) {
85
+ return
86
+ (int32_t )vgetq_lane_s16 (v , 0 ) + (int32_t )vgetq_lane_s16 (v , 1 ) +
87
+ (int32_t )vgetq_lane_s16 (v , 2 ) + (int32_t )vgetq_lane_s16 (v , 3 ) +
88
+ (int32_t )vgetq_lane_s16 (v , 4 ) + (int32_t )vgetq_lane_s16 (v , 5 ) +
89
+ (int32_t )vgetq_lane_s16 (v , 6 ) + (int32_t )vgetq_lane_s16 (v , 7 );
90
+ }
91
+
92
+ inline static int16x8_t vpaddq_s16 (int16x8_t a , int16x8_t b ) {
93
+ int16x4_t a0 = vpadd_s16 (vget_low_s16 (a ), vget_high_s16 (a ));
94
+ int16x4_t b0 = vpadd_s16 (vget_low_s16 (b ), vget_high_s16 (b ));
95
+ return vcombine_s16 (a0 , b0 );
96
+ }
97
+
98
+ inline static int32x4_t vpaddq_s32 (int32x4_t a , int32x4_t b ) {
99
+ int32x2_t a0 = vpadd_s32 (vget_low_s32 (a ), vget_high_s32 (a ));
100
+ int32x2_t b0 = vpadd_s32 (vget_low_s32 (b ), vget_high_s32 (b ));
101
+ return vcombine_s32 (a0 , b0 );
102
+ }
103
+
104
+ inline static int32_t vaddvq_s32 (int32x4_t v ) {
105
+ return vgetq_lane_s32 (v , 0 ) + vgetq_lane_s32 (v , 1 ) + vgetq_lane_s32 (v , 2 ) + vgetq_lane_s32 (v , 3 );
106
+ }
107
+
108
+ inline static float vaddvq_f32 (float32x4_t v ) {
109
+ return vgetq_lane_f32 (v , 0 ) + vgetq_lane_f32 (v , 1 ) + vgetq_lane_f32 (v , 2 ) + vgetq_lane_f32 (v , 3 );
110
+ }
111
+
112
+ inline static float vmaxvq_f32 (float32x4_t v ) {
113
+ return
114
+ MAX (MAX (vgetq_lane_f32 (v , 0 ), vgetq_lane_f32 (v , 1 )),
115
+ MAX (vgetq_lane_f32 (v , 2 ), vgetq_lane_f32 (v , 3 )));
116
+ }
117
+
118
+ inline static int32x4_t vcvtnq_s32_f32 (float32x4_t v ) {
119
+ int32x4_t res ;
120
+
121
+ res [0 ] = roundf (vgetq_lane_f32 (v , 0 ));
122
+ res [1 ] = roundf (vgetq_lane_f32 (v , 1 ));
123
+ res [2 ] = roundf (vgetq_lane_f32 (v , 2 ));
124
+ res [3 ] = roundf (vgetq_lane_f32 (v , 3 ));
125
+
126
+ return res ;
127
+ }
128
+
129
+ inline static uint8x8_t vzip1_u8 (uint8x8_t a , uint8x8_t b ) {
130
+ uint8x8_t res ;
131
+
132
+ res [0 ] = a [0 ]; res [1 ] = b [0 ];
133
+ res [2 ] = a [1 ]; res [3 ] = b [1 ];
134
+ res [4 ] = a [2 ]; res [5 ] = b [2 ];
135
+ res [6 ] = a [3 ]; res [7 ] = b [3 ];
136
+
137
+ return res ;
138
+ }
139
+
140
+ inline static uint8x8_t vzip2_u8 (uint8x8_t a , uint8x8_t b ) {
141
+ uint8x8_t res ;
142
+
143
+ res [0 ] = a [4 ]; res [1 ] = b [4 ];
144
+ res [2 ] = a [5 ]; res [3 ] = b [5 ];
145
+ res [4 ] = a [6 ]; res [5 ] = b [6 ];
146
+ res [6 ] = a [7 ]; res [7 ] = b [7 ];
147
+
148
+ return res ;
149
+ }
150
+
151
+ // vld1q_s16_x2
152
+ // vld1q_u8_x2
153
+ // vld1q_u8_x4
154
+ // vld1q_s8_x2
155
+ // vld1q_s8_x4
156
+ // TODO: double-check these work correctly
157
+
158
+ typedef struct ggml_int16x8x2_t {
159
+ int16x8_t val [2 ];
160
+ } ggml_int16x8x2_t ;
161
+
162
+ inline static ggml_int16x8x2_t ggml_vld1q_s16_x2 (const int16_t * ptr ) {
163
+ ggml_int16x8x2_t res ;
164
+
165
+ res .val [0 ] = vld1q_s16 (ptr + 0 );
166
+ res .val [1 ] = vld1q_s16 (ptr + 8 );
167
+
168
+ return res ;
169
+ }
170
+
171
+ typedef struct ggml_uint8x16x2_t {
172
+ uint8x16_t val [2 ];
173
+ } ggml_uint8x16x2_t ;
174
+
175
+ inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2 (const uint8_t * ptr ) {
176
+ ggml_uint8x16x2_t res ;
177
+
178
+ res .val [0 ] = vld1q_u8 (ptr + 0 );
179
+ res .val [1 ] = vld1q_u8 (ptr + 16 );
180
+
181
+ return res ;
182
+ }
183
+
184
+ typedef struct ggml_uint8x16x4_t {
185
+ uint8x16_t val [4 ];
186
+ } ggml_uint8x16x4_t ;
187
+
188
+ inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4 (const uint8_t * ptr ) {
189
+ ggml_uint8x16x4_t res ;
190
+
191
+ res .val [0 ] = vld1q_u8 (ptr + 0 );
192
+ res .val [1 ] = vld1q_u8 (ptr + 16 );
193
+ res .val [2 ] = vld1q_u8 (ptr + 32 );
194
+ res .val [3 ] = vld1q_u8 (ptr + 48 );
195
+
196
+ return res ;
197
+ }
198
+
199
+ typedef struct ggml_int8x16x2_t {
200
+ int8x16_t val [2 ];
201
+ } ggml_int8x16x2_t ;
202
+
203
+ inline static ggml_int8x16x2_t ggml_vld1q_s8_x2 (const int8_t * ptr ) {
204
+ ggml_int8x16x2_t res ;
205
+
206
+ res .val [0 ] = vld1q_s8 (ptr + 0 );
207
+ res .val [1 ] = vld1q_s8 (ptr + 16 );
208
+
209
+ return res ;
210
+ }
211
+
212
+ typedef struct ggml_int8x16x4_t {
213
+ int8x16_t val [4 ];
214
+ } ggml_int8x16x4_t ;
215
+
216
+ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4 (const int8_t * ptr ) {
217
+ ggml_int8x16x4_t res ;
218
+
219
+ res .val [0 ] = vld1q_s8 (ptr + 0 );
220
+ res .val [1 ] = vld1q_s8 (ptr + 16 );
221
+ res .val [2 ] = vld1q_s8 (ptr + 32 );
222
+ res .val [3 ] = vld1q_s8 (ptr + 48 );
223
+
224
+ return res ;
225
+ }
226
+
227
+ // NOTE: not tested
228
+ inline static int8x16_t ggml_vqtbl1q_s8 (int8x16_t a , uint8x16_t b ) {
229
+ int8x16_t res ;
230
+
231
+ res [ 0 ] = a [b [ 0 ]];
232
+ res [ 1 ] = a [b [ 1 ]];
233
+ res [ 2 ] = a [b [ 2 ]];
234
+ res [ 3 ] = a [b [ 3 ]];
235
+ res [ 4 ] = a [b [ 4 ]];
236
+ res [ 5 ] = a [b [ 5 ]];
237
+ res [ 6 ] = a [b [ 6 ]];
238
+ res [ 7 ] = a [b [ 7 ]];
239
+ res [ 8 ] = a [b [ 8 ]];
240
+ res [ 9 ] = a [b [ 9 ]];
241
+ res [10 ] = a [b [10 ]];
242
+ res [11 ] = a [b [11 ]];
243
+ res [12 ] = a [b [12 ]];
244
+ res [13 ] = a [b [13 ]];
245
+ res [14 ] = a [b [14 ]];
246
+ res [15 ] = a [b [15 ]];
247
+
248
+ return res ;
249
+ }
250
+
251
+ // NOTE: not tested
252
+ inline static uint8x16_t ggml_vqtbl1q_u8 (uint8x16_t a , uint8x16_t b ) {
253
+ uint8x16_t res ;
254
+
255
+ res [ 0 ] = a [b [ 0 ]];
256
+ res [ 1 ] = a [b [ 1 ]];
257
+ res [ 2 ] = a [b [ 2 ]];
258
+ res [ 3 ] = a [b [ 3 ]];
259
+ res [ 4 ] = a [b [ 4 ]];
260
+ res [ 5 ] = a [b [ 5 ]];
261
+ res [ 6 ] = a [b [ 6 ]];
262
+ res [ 7 ] = a [b [ 7 ]];
263
+ res [ 8 ] = a [b [ 8 ]];
264
+ res [ 9 ] = a [b [ 9 ]];
265
+ res [10 ] = a [b [10 ]];
266
+ res [11 ] = a [b [11 ]];
267
+ res [12 ] = a [b [12 ]];
268
+ res [13 ] = a [b [13 ]];
269
+ res [14 ] = a [b [14 ]];
270
+ res [15 ] = a [b [15 ]];
271
+
272
+ return res ;
273
+ }
274
+
275
+ #else
276
+
277
// on AArch64 builds the ggml_ names map directly onto the intrinsic or type of the same name

// multi-vector loads and their result types
#define ggml_int16x8x2_t  int16x8x2_t
#define ggml_vld1q_s16_x2 vld1q_s16_x2

#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_vld1q_u8_x2  vld1q_u8_x2

#define ggml_uint8x16x4_t uint8x16x4_t
#define ggml_vld1q_u8_x4  vld1q_u8_x4

#define ggml_int8x16x2_t  int8x16x2_t
#define ggml_vld1q_s8_x2  vld1q_s8_x2

#define ggml_int8x16x4_t  int8x16x4_t
#define ggml_vld1q_s8_x4  vld1q_s8_x4

// 16-byte table lookups
#define ggml_vqtbl1q_s8   vqtbl1q_s8
#define ggml_vqtbl1q_u8   vqtbl1q_u8
290
+
291
+ #endif // !defined(__aarch64__)
292
+
293
+ #if !defined(__ARM_FEATURE_DOTPROD )
294
+
295
+ inline static int32x4_t ggml_vdotq_s32 (int32x4_t acc , int8x16_t a , int8x16_t b ) {
296
+ const int16x8_t p0 = vmull_s8 (vget_low_s8 (a ), vget_low_s8 (b ));
297
+ const int16x8_t p1 = vmull_s8 (vget_high_s8 (a ), vget_high_s8 (b ));
298
+
299
+ return vaddq_s32 (acc , vaddq_s32 (vpaddlq_s16 (p0 ), vpaddlq_s16 (p1 )));
300
+ }
301
+
302
+ #else
303
+
304
+ #define ggml_vdotq_s32 (a , b , c ) vdotq_s32(a, b, c)
305
+
306
+ #endif // !defined(__ARM_FEATURE_DOTPROD)
307
+
308
+ #endif // defined(__ARM_NEON)
309
+
310
+ #if defined(__ARM_NEON ) && !defined(__MSC_VER )
311
+
58
312
#define GGML_COMPUTE_FP16_TO_FP32 (x ) ggml_compute_fp16_to_fp32(x)
59
313
#define GGML_COMPUTE_FP32_TO_FP16 (x ) ggml_compute_fp32_to_fp16(x)
60
314
@@ -75,8 +329,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
75
329
76
330
#else
77
331
78
- typedef uint16_t ggml_fp16_internal_t ;
79
-
80
332
#ifdef __wasm_simd128__
81
333
#include <wasm_simd128.h>
82
334
#else
@@ -221,7 +473,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
221
473
222
474
#endif // __F16C__
223
475
224
- #endif // __ARM_NEON
476
+ #endif // defined(__ARM_NEON) && !defined(_MSC_VER)
225
477
226
478
// precomputed f32 table for f16 (256 KB)
227
479
// defined in ggml.c, initialized in ggml_init()
0 commit comments