#define NOINLINE __attribute__((__noinline__))
#endif

- #if defined(__ARM_NEON) || defined(__AVX512F__)
+ #if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
@@ -109,6 +109,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+ #if defined(__VXE__) || defined(__VXE2__)
+ inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
+ inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
+ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
+ #endif
+
#if defined(__MMA__)
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
@@ -162,6 +168,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
#endif
#endif

+ #if defined(__VXE__) || defined(__VXE2__)
+ template <>
+ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+     return vec_madd(a, b, c);
+ }
+ #endif
+
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

@@ -178,6 +191,13 @@ inline float hsum(float16x8_t x) {
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+ #if defined(__VXE__) || defined(__VXE2__)
+ inline float hsum(float32x4_t x) {
+     float32x4_t tmp = x + vec_reve(x);
+     return tmp[0] + tmp[1];
+ }
+ #endif
+
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
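Note on the VXE hsum added above: vec_reve reverses the four lanes, so x + vec_reve(x) holds {x0+x3, x1+x2, x2+x1, x3+x0}, and the first two lanes together already contain the whole sum. A minimal portable sketch of that reduction (plain arrays stand in for float32x4_t and vec_reve, which need an s390x build):

#include <cassert>

// Scalar model of the VXE reduction: reverse the lanes, add, then sum lanes 0 and 1.
static float hsum_model(const float x[4]) {
    float rev[4] = { x[3], x[2], x[1], x[0] };            // models vec_reve(x)
    float tmp[4];
    for (int i = 0; i < 4; i++) tmp[i] = x[i] + rev[i];   // models x + vec_reve(x)
    return tmp[0] + tmp[1];                               // (x0 + x3) + (x1 + x2) = full sum
}

int main() {
    float v[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    assert(hsum_model(v) == 10.0f);
    return 0;
}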
@@ -227,6 +247,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
#endif // _MSC_VER
#endif // __ARM_NEON

+ #if defined(__VXE__) || defined(__VXE2__)
+ template <> inline float32x4_t load(const ggml_fp16_t * p) {
+     float tmp[4];
+
+     for (int i = 0; i < 4; i++) {
+         tmp[i] = GGML_FP16_TO_FP32(p[i]);
+     }
+
+     return vec_xl(0, (const float *)(tmp));
+ }
+ template <> inline float32x4_t load(const float * p) {
+     return vec_xl(0, p);
+ }
+ #endif
+
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
@@ -3319,6 +3354,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
+ #elif defined(__VXE__) || defined(__VXE2__)
+       if (n < 4)
+           return false;
+       tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
+           k, (const float *)A, lda,
+           (const float *)B, ldb,
+           (float *)C, ldc};
+       return tb.matmul(m, n);
#elif defined(__MMA__)
        if (k % 8)
            return false;
@@ -3410,6 +3453,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
+ #elif defined(__VXE__) || defined(__VXE2__)
+       if (n < 4)
+           return false;
+       if (Btype == GGML_TYPE_F16) {
+           tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+               k, (const ggml_fp16_t *)A, lda,
+               (const ggml_fp16_t *)B, ldb,
+               (float *)C, ldc};
+           return tb.matmul(m, n);
+       }
#endif
        return false;
    }
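Taken together, the primitives this change adds (load, madd, hsum over float32x4_t) are what the tinyBLAS inner loop composes to produce each element of C: accumulate acc = madd(load(&A[...]), load(&B[...]), acc) across k, four lanes at a time, then reduce with hsum. A rough, portable model of that kernel (plain float loops stand in for the VXE types and intrinsics; the function name and the assumption that k is a multiple of 4 are illustrative, not taken from the source):

#include <cassert>

// Scalar model of the vectorized dot product behind one C element:
// acc = madd(load(a), load(b), acc) over k, then hsum(acc).
static float dot_model(const float *a, const float *b, int k) {
    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };       // stands in for a float32x4_t accumulator
    for (int i = 0; i < k; i += 4)                   // assumes k % 4 == 0
        for (int l = 0; l < 4; l++)
            acc[l] = a[i + l] * b[i + l] + acc[l];   // stands in for madd(load(&a[i]), load(&b[i]), acc)
    return (acc[0] + acc[3]) + (acc[1] + acc[2]);    // stands in for hsum(acc)
}

int main() {
    float a[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
    float b[8] = { 1, 1, 1, 1, 1, 1, 1, 1 };
    assert(dot_model(a, b, 8) == 36.0f);             // 1 + 2 + ... + 8
    return 0;
}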