Skip to content

Commit c0d1b3e

Browse files
authored
ggml : move 32-bit arm compat in ggml-impl.h (#6865)
ggml-ci
1 parent abd3314 commit c0d1b3e

File tree

2 files changed

+256
-291
lines changed

2 files changed

+256
-291
lines changed

ggml-impl.h

Lines changed: 256 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,270 @@ extern "C" {
4545
// 16-bit float
4646
// on Arm, we use __fp16
4747
// on x86, we use uint16_t
48-
#if defined(__ARM_NEON) && !defined(_MSC_VER)
48+
#if defined(__ARM_NEON)
4949

5050
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
5151
//
5252
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
5353
//
5454
#include <arm_neon.h>
5555

56+
#ifdef _MSC_VER
57+
58+
typedef uint16_t ggml_fp16_internal_t;
59+
60+
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
61+
62+
#else
63+
5664
typedef __fp16 ggml_fp16_internal_t;
5765

66+
#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
67+
68+
#endif // _MSC_VER
69+
70+
#if !defined(__aarch64__)
71+
72+
// 32-bit ARM compatibility
73+
74+
// vaddvq_s16
75+
// vpaddq_s16
76+
// vpaddq_s32
77+
// vaddvq_s32
78+
// vaddvq_f32
79+
// vmaxvq_f32
80+
// vcvtnq_s32_f32
81+
// vzip1_u8
82+
// vzip2_u8
83+
84+
inline static int32_t vaddvq_s16(int16x8_t v) {
85+
return
86+
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
87+
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
88+
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
89+
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
90+
}
91+
92+
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
93+
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
94+
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
95+
return vcombine_s16(a0, b0);
96+
}
97+
98+
inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
99+
int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
100+
int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
101+
return vcombine_s32(a0, b0);
102+
}
103+
104+
inline static int32_t vaddvq_s32(int32x4_t v) {
105+
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
106+
}
107+
108+
inline static float vaddvq_f32(float32x4_t v) {
109+
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
110+
}
111+
112+
inline static float vmaxvq_f32(float32x4_t v) {
113+
return
114+
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
115+
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
116+
}
117+
118+
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
119+
int32x4_t res;
120+
121+
res[0] = roundf(vgetq_lane_f32(v, 0));
122+
res[1] = roundf(vgetq_lane_f32(v, 1));
123+
res[2] = roundf(vgetq_lane_f32(v, 2));
124+
res[3] = roundf(vgetq_lane_f32(v, 3));
125+
126+
return res;
127+
}
128+
129+
inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
130+
uint8x8_t res;
131+
132+
res[0] = a[0]; res[1] = b[0];
133+
res[2] = a[1]; res[3] = b[1];
134+
res[4] = a[2]; res[5] = b[2];
135+
res[6] = a[3]; res[7] = b[3];
136+
137+
return res;
138+
}
139+
140+
inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
141+
uint8x8_t res;
142+
143+
res[0] = a[4]; res[1] = b[4];
144+
res[2] = a[5]; res[3] = b[5];
145+
res[4] = a[6]; res[5] = b[6];
146+
res[6] = a[7]; res[7] = b[7];
147+
148+
return res;
149+
}
150+
151+
// vld1q_s16_x2
152+
// vld1q_u8_x2
153+
// vld1q_u8_x4
154+
// vld1q_s8_x2
155+
// vld1q_s8_x4
156+
// TODO: double-check these work correctly
157+
158+
typedef struct ggml_int16x8x2_t {
159+
int16x8_t val[2];
160+
} ggml_int16x8x2_t;
161+
162+
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
163+
ggml_int16x8x2_t res;
164+
165+
res.val[0] = vld1q_s16(ptr + 0);
166+
res.val[1] = vld1q_s16(ptr + 8);
167+
168+
return res;
169+
}
170+
171+
typedef struct ggml_uint8x16x2_t {
172+
uint8x16_t val[2];
173+
} ggml_uint8x16x2_t;
174+
175+
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
176+
ggml_uint8x16x2_t res;
177+
178+
res.val[0] = vld1q_u8(ptr + 0);
179+
res.val[1] = vld1q_u8(ptr + 16);
180+
181+
return res;
182+
}
183+
184+
typedef struct ggml_uint8x16x4_t {
185+
uint8x16_t val[4];
186+
} ggml_uint8x16x4_t;
187+
188+
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
189+
ggml_uint8x16x4_t res;
190+
191+
res.val[0] = vld1q_u8(ptr + 0);
192+
res.val[1] = vld1q_u8(ptr + 16);
193+
res.val[2] = vld1q_u8(ptr + 32);
194+
res.val[3] = vld1q_u8(ptr + 48);
195+
196+
return res;
197+
}
198+
199+
typedef struct ggml_int8x16x2_t {
200+
int8x16_t val[2];
201+
} ggml_int8x16x2_t;
202+
203+
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
204+
ggml_int8x16x2_t res;
205+
206+
res.val[0] = vld1q_s8(ptr + 0);
207+
res.val[1] = vld1q_s8(ptr + 16);
208+
209+
return res;
210+
}
211+
212+
typedef struct ggml_int8x16x4_t {
213+
int8x16_t val[4];
214+
} ggml_int8x16x4_t;
215+
216+
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
217+
ggml_int8x16x4_t res;
218+
219+
res.val[0] = vld1q_s8(ptr + 0);
220+
res.val[1] = vld1q_s8(ptr + 16);
221+
res.val[2] = vld1q_s8(ptr + 32);
222+
res.val[3] = vld1q_s8(ptr + 48);
223+
224+
return res;
225+
}
226+
227+
// NOTE: not tested
228+
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
229+
int8x16_t res;
230+
231+
res[ 0] = a[b[ 0]];
232+
res[ 1] = a[b[ 1]];
233+
res[ 2] = a[b[ 2]];
234+
res[ 3] = a[b[ 3]];
235+
res[ 4] = a[b[ 4]];
236+
res[ 5] = a[b[ 5]];
237+
res[ 6] = a[b[ 6]];
238+
res[ 7] = a[b[ 7]];
239+
res[ 8] = a[b[ 8]];
240+
res[ 9] = a[b[ 9]];
241+
res[10] = a[b[10]];
242+
res[11] = a[b[11]];
243+
res[12] = a[b[12]];
244+
res[13] = a[b[13]];
245+
res[14] = a[b[14]];
246+
res[15] = a[b[15]];
247+
248+
return res;
249+
}
250+
251+
// NOTE: not tested
252+
inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
253+
uint8x16_t res;
254+
255+
res[ 0] = a[b[ 0]];
256+
res[ 1] = a[b[ 1]];
257+
res[ 2] = a[b[ 2]];
258+
res[ 3] = a[b[ 3]];
259+
res[ 4] = a[b[ 4]];
260+
res[ 5] = a[b[ 5]];
261+
res[ 6] = a[b[ 6]];
262+
res[ 7] = a[b[ 7]];
263+
res[ 8] = a[b[ 8]];
264+
res[ 9] = a[b[ 9]];
265+
res[10] = a[b[10]];
266+
res[11] = a[b[11]];
267+
res[12] = a[b[12]];
268+
res[13] = a[b[13]];
269+
res[14] = a[b[14]];
270+
res[15] = a[b[15]];
271+
272+
return res;
273+
}
274+
275+
#else
276+
277+
#define ggml_int16x8x2_t int16x8x2_t
278+
#define ggml_uint8x16x2_t uint8x16x2_t
279+
#define ggml_uint8x16x4_t uint8x16x4_t
280+
#define ggml_int8x16x2_t int8x16x2_t
281+
#define ggml_int8x16x4_t int8x16x4_t
282+
283+
#define ggml_vld1q_s16_x2 vld1q_s16_x2
284+
#define ggml_vld1q_u8_x2 vld1q_u8_x2
285+
#define ggml_vld1q_u8_x4 vld1q_u8_x4
286+
#define ggml_vld1q_s8_x2 vld1q_s8_x2
287+
#define ggml_vld1q_s8_x4 vld1q_s8_x4
288+
#define ggml_vqtbl1q_s8 vqtbl1q_s8
289+
#define ggml_vqtbl1q_u8 vqtbl1q_u8
290+
291+
#endif // !defined(__aarch64__)
292+
293+
#if !defined(__ARM_FEATURE_DOTPROD)
294+
295+
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
296+
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
297+
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
298+
299+
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
300+
}
301+
302+
#else
303+
304+
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
305+
306+
#endif // !defined(__ARM_FEATURE_DOTPROD)
307+
308+
#endif // defined(__ARM_NEON)
309+
310+
#if defined(__ARM_NEON) && !defined(__MSC_VER)
311+
58312
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
59313
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
60314

@@ -75,8 +329,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
75329

76330
#else
77331

78-
typedef uint16_t ggml_fp16_internal_t;
79-
80332
#ifdef __wasm_simd128__
81333
#include <wasm_simd128.h>
82334
#else
@@ -221,7 +473,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
221473

222474
#endif // __F16C__
223475

224-
#endif // __ARM_NEON
476+
#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
225477

226478
// precomputed f32 table for f16 (256 KB)
227479
// defined in ggml.c, initialized in ggml_init()

0 commit comments

Comments
 (0)