@@ -45,14 +45,24 @@ extern "C" {
45
45
// 16-bit float
46
46
// on Arm, we use __fp16
47
47
// on x86, we use uint16_t
48
- #if defined(__ARM_NEON ) && !defined( _MSC_VER )
48
+ #if defined(__ARM_NEON )
49
49
50
50
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
51
51
//
52
52
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
53
53
//
54
54
#include <arm_neon.h>
55
55
56
+ #ifdef _MSC_VER
57
+
58
+ typedef uint16_t ggml_fp16_internal_t ;
59
+
60
+ #define ggml_vld1q_u32 (w ,x ,y ,z ) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
61
+
62
+ #else
63
+
64
+ #define ggml_vld1q_u32 (w ,x ,y ,z ) { (w), (x), (y), (z) }
65
+
56
66
typedef __fp16 ggml_fp16_internal_t ;
57
67
58
68
#define GGML_COMPUTE_FP16_TO_FP32 (x ) ggml_compute_fp16_to_fp32(x)
@@ -73,8 +83,248 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
73
83
return res ;
74
84
}
75
85
86
+ #endif // _MSC_VER
87
+
88
+ #if !defined(__aarch64__ )
89
+
90
+ // 32-bit ARM compatibility
91
+
92
+ // vaddvq_s16
93
+ // vpaddq_s16
94
+ // vpaddq_s32
95
+ // vaddvq_s32
96
+ // vaddvq_f32
97
+ // vmaxvq_f32
98
+ // vcvtnq_s32_f32
99
+ // vzip1_u8
100
+ // vzip2_u8
101
+
102
// 32-bit ARM stand-in for vaddvq_s16: horizontal sum of the eight int16 lanes.
// Every lane is widened to int32_t before it is added, and the result is
// returned as int32_t.
inline static int32_t vaddvq_s16(int16x8_t v) {
    const int32_t l0 = (int32_t) vgetq_lane_s16(v, 0);
    const int32_t l1 = (int32_t) vgetq_lane_s16(v, 1);
    const int32_t l2 = (int32_t) vgetq_lane_s16(v, 2);
    const int32_t l3 = (int32_t) vgetq_lane_s16(v, 3);
    const int32_t l4 = (int32_t) vgetq_lane_s16(v, 4);
    const int32_t l5 = (int32_t) vgetq_lane_s16(v, 5);
    const int32_t l6 = (int32_t) vgetq_lane_s16(v, 6);
    const int32_t l7 = (int32_t) vgetq_lane_s16(v, 7);

    return l0 + l1 + l2 + l3 + l4 + l5 + l6 + l7;
}
109
+
110
// 32-bit ARM stand-in for vpaddq_s16: pairwise add across two int16x8 vectors.
// The low four result lanes are the adjacent-pair sums of a, the high four
// lanes are the adjacent-pair sums of b.
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
    return vcombine_s16(
        vpadd_s16(vget_low_s16(a), vget_high_s16(a)),
        vpadd_s16(vget_low_s16(b), vget_high_s16(b)));
}
115
+
116
// 32-bit ARM stand-in for vpaddq_s32: pairwise add across two int32x4 vectors.
// The low two result lanes are the adjacent-pair sums of a, the high two
// lanes are the adjacent-pair sums of b.
inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
    return vcombine_s32(
        vpadd_s32(vget_low_s32(a), vget_high_s32(a)),
        vpadd_s32(vget_low_s32(b), vget_high_s32(b)));
}
121
+
122
// 32-bit ARM stand-in for vaddvq_s32: horizontal sum of the four int32 lanes,
// added left to right in lane order.
inline static int32_t vaddvq_s32(int32x4_t v) {
    const int32_t l0 = vgetq_lane_s32(v, 0);
    const int32_t l1 = vgetq_lane_s32(v, 1);
    const int32_t l2 = vgetq_lane_s32(v, 2);
    const int32_t l3 = vgetq_lane_s32(v, 3);

    return l0 + l1 + l2 + l3;
}
125
+
126
// 32-bit ARM stand-in for vaddvq_f32: horizontal sum of the four float lanes.
// The lanes are added strictly left to right (((l0 + l1) + l2) + l3); keep
// that order, since float addition is not associative.
inline static float vaddvq_f32(float32x4_t v) {
    const float l0 = vgetq_lane_f32(v, 0);
    const float l1 = vgetq_lane_f32(v, 1);
    const float l2 = vgetq_lane_f32(v, 2);
    const float l3 = vgetq_lane_f32(v, 3);

    return l0 + l1 + l2 + l3;
}
129
+
130
// 32-bit ARM stand-in for vmaxvq_f32: maximum over the four float lanes.
// Reduces lanes (0,1) and (2,3) first, then combines the two partial results,
// all through the MAX macro defined elsewhere in this header.
inline static float vmaxvq_f32(float32x4_t v) {
    const float max01 = MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1));
    const float max23 = MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3));

    return MAX(max01, max23);
}
135
+
136
// 32-bit ARM stand-in for vcvtnq_s32_f32: converts each float lane to int32,
// rounding to nearest with ties to even (the rounding used by the AArch64
// intrinsic). Plain roundf() rounds ties away from zero (2.5f -> 3), so exact
// .5 ties are corrected explicitly and do not depend on the current
// floating-point rounding mode.
//
// NOTE(review): NaN and values outside the int32_t range are not saturated
// here, unlike the hardware instruction - callers are presumed to pass
// in-range values; confirm.
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
    float   src[4];
    int32_t dst[4];

    vst1q_f32(src, v);

    for (int i = 0; i < 4; ++i) {
        const float x = src[i];
        float r = roundf(x);

        // an exact tie is the only case where |r - x| == 0.5
        if (r - x == 0.5f || x - r == 0.5f) {
            // 2*round(x/2) lands on the even neighbour of a .5 tie
            r = 2.0f * roundf(0.5f * x);
        }

        dst[i] = (int32_t) r;
    }

    return vld1q_s32(dst);
}
146
+
147
// 32-bit ARM stand-in for vzip1_u8: interleaves the low four bytes of a and b,
// producing { a0, b0, a1, b1, a2, b2, a3, b3 }.
// This is the first half of the two-vector result of vzip_u8.
inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
    const uint8x8x2_t zipped = vzip_u8(a, b);

    return zipped.val[0];
}
157
+
158
// 32-bit ARM stand-in for vzip2_u8: interleaves the high four bytes of a and b,
// producing { a4, b4, a5, b5, a6, b6, a7, b7 }.
// This is the second half of the two-vector result of vzip_u8.
inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
    const uint8x8x2_t zipped = vzip_u8(a, b);

    return zipped.val[1];
}
168
+
169
+ // vld1q_s16_x2
170
+ // vld1q_u8_x2
171
+ // vld1q_u8_x4
172
+ // vld1q_s8_x2
173
+ // vld1q_s8_x4
174
+ // TODO: double-check these work correctly
175
+
176
// Pair of int16x8 vectors; same val[] layout that the ggml_int16x8x2_t alias
// maps to on AArch64 (int16x8x2_t).
typedef struct ggml_int16x8x2_t {
    int16x8_t val[2];
} ggml_int16x8x2_t;

// Loads 16 consecutive int16_t values starting at ptr as two 8-lane vectors:
// val[0] = ptr[0..7], val[1] = ptr[8..15].
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
    ggml_int16x8x2_t pair;

    for (int i = 0; i < 2; ++i) {
        pair.val[i] = vld1q_s16(ptr + 8*i);
    }

    return pair;
}
188
+
189
// Pair of uint8x16 vectors; same val[] layout that the ggml_uint8x16x2_t alias
// maps to on AArch64 (uint8x16x2_t).
typedef struct ggml_uint8x16x2_t {
    uint8x16_t val[2];
} ggml_uint8x16x2_t;

// Loads 32 consecutive bytes starting at ptr as two 16-lane vectors:
// val[0] = ptr[0..15], val[1] = ptr[16..31].
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
    ggml_uint8x16x2_t pair;

    for (int i = 0; i < 2; ++i) {
        pair.val[i] = vld1q_u8(ptr + 16*i);
    }

    return pair;
}
201
+
202
// Four uint8x16 vectors; same val[] layout that the ggml_uint8x16x4_t alias
// maps to on AArch64 (uint8x16x4_t).
typedef struct ggml_uint8x16x4_t {
    uint8x16_t val[4];
} ggml_uint8x16x4_t;

// Loads 64 consecutive bytes starting at ptr as four 16-lane vectors:
// val[k] = ptr[16*k .. 16*k + 15] for k = 0..3.
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
    ggml_uint8x16x4_t quad;

    for (int i = 0; i < 4; ++i) {
        quad.val[i] = vld1q_u8(ptr + 16*i);
    }

    return quad;
}
216
+
217
// Pair of int8x16 vectors; same val[] layout that the ggml_int8x16x2_t alias
// maps to on AArch64 (int8x16x2_t).
typedef struct ggml_int8x16x2_t {
    int8x16_t val[2];
} ggml_int8x16x2_t;

// Loads 32 consecutive int8_t values starting at ptr as two 16-lane vectors:
// val[0] = ptr[0..15], val[1] = ptr[16..31].
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
    ggml_int8x16x2_t pair;

    for (int i = 0; i < 2; ++i) {
        pair.val[i] = vld1q_s8(ptr + 16*i);
    }

    return pair;
}
229
+
230
// Four int8x16 vectors; same val[] layout that the ggml_int8x16x4_t alias
// maps to on AArch64 (int8x16x4_t).
typedef struct ggml_int8x16x4_t {
    int8x16_t val[4];
} ggml_int8x16x4_t;

// Loads 64 consecutive int8_t values starting at ptr as four 16-lane vectors:
// val[k] = ptr[16*k .. 16*k + 15] for k = 0..3.
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
    ggml_int8x16x4_t quad;

    for (int i = 0; i < 4; ++i) {
        quad.val[i] = vld1q_s8(ptr + 16*i);
    }

    return quad;
}
244
+
245
// 32-bit ARM stand-in for vqtbl1q_s8: byte table lookup.
//   a - 16-entry table of int8 values
//   b - 16 lookup indices, one per output lane
// Output lane i is a[b[i]] when b[i] < 16 and 0 otherwise. The zero result for
// out-of-range indices follows the AArch64 TBL instruction and also keeps the
// lookup from reading past the end of the table.
// NOTE: not tested on hardware
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
    int8_t  table[16];
    uint8_t index[16];
    int8_t  out[16];

    vst1q_s8(table, a);
    vst1q_u8(index, b);

    for (int i = 0; i < 16; ++i) {
        out[i] = index[i] < 16 ? table[index[i]] : 0;
    }

    return vld1q_s8(out);
}
268
+
269
// 32-bit ARM stand-in for vqtbl1q_u8: byte table lookup.
//   a - 16-entry table of uint8 values
//   b - 16 lookup indices, one per output lane
// Output lane i is a[b[i]] when b[i] < 16 and 0 otherwise. The zero result for
// out-of-range indices follows the AArch64 TBL instruction and also keeps the
// lookup from reading past the end of the table.
// NOTE: not tested on hardware
inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
    uint8_t table[16];
    uint8_t index[16];
    uint8_t out[16];

    vst1q_u8(table, a);
    vst1q_u8(index, b);

    for (int i = 0; i < 16; ++i) {
        out[i] = index[i] < 16 ? table[index[i]] : 0;
    }

    return vld1q_u8(out);
}
292
+
293
+ #else
294
+
295
+ #define ggml_int16x8x2_t int16x8x2_t
296
+ #define ggml_uint8x16x2_t uint8x16x2_t
297
+ #define ggml_uint8x16x4_t uint8x16x4_t
298
+ #define ggml_int8x16x2_t int8x16x2_t
299
+ #define ggml_int8x16x4_t int8x16x4_t
300
+
301
+ #define ggml_vld1q_s16_x2 vld1q_s16_x2
302
+ #define ggml_vld1q_u8_x2 vld1q_u8_x2
303
+ #define ggml_vld1q_u8_x4 vld1q_u8_x4
304
+ #define ggml_vld1q_s8_x2 vld1q_s8_x2
305
+ #define ggml_vld1q_s8_x4 vld1q_s8_x4
306
+ #define ggml_vqtbl1q_s8 vqtbl1q_s8
307
+ #define ggml_vqtbl1q_u8 vqtbl1q_u8
308
+
309
+ #endif // !defined(__aarch64__)
310
+
311
+ #if !defined(__ARM_FEATURE_DOTPROD )
312
+
313
// Fallback for vdotq_s32 when __ARM_FEATURE_DOTPROD is not defined.
// Multiplies the int8 lanes of a and b into int16 products, pairwise-widens
// them to int32 and adds the result to acc.
//
// Output lane i receives the products of elements 2i, 2i+1, 8+2i and 8+2i+1.
// NOTE(review): this grouping is not the 4-consecutive-bytes grouping of the
// hardware dot product, so only the sum over all four lanes is expected to
// match - confirm callers reduce the result horizontally.
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t prod_lo = vmull_s8(vget_low_s8(a),  vget_low_s8(b));
    const int16x8_t prod_hi = vmull_s8(vget_high_s8(a), vget_high_s8(b));

    const int32x4_t sum_lo = vpaddlq_s16(prod_lo);
    const int32x4_t sum_hi = vpaddlq_s16(prod_hi);

    return vaddq_s32(acc, vaddq_s32(sum_lo, sum_hi));
}
319
+
76
320
#else
77
321
322
+ #define ggml_vdotq_s32 (a , b , c ) vdotq_s32(a, b, c)
323
+
324
+ #endif // !defined(__ARM_FEATURE_DOTPROD)
325
+
326
+ #else // defined(__ARM_NEON)
327
+
78
328
typedef uint16_t ggml_fp16_internal_t ;
79
329
80
330
#ifdef __wasm_simd128__
0 commit comments