Skip to content

Commit eedc38a

Browse files
swolchokfacebook-github-bot
authored andcommitted
FFHT: ARM NEON port (#5289)
Summary: Pull Request resolved: #5289 Patch the code generator to be capable of generating NEON code and leave it configured to do that since we already have the checked-in generated AVX and SSE code. Generated code size was a potential issue so I also patched the generator to 1) reuse generated code for previous smaller sizes whereapplicable and 2) choose the smallest code that isn't more than 10% slower than the very fastest code. ghstack-source-id: 242230777 exported-using-ghexport Reviewed By: kimishpatel Differential Revision: D60194970 fbshipit-source-id: 37aab6813222c5a965c060286b5d5453ced22a0c
1 parent d3fb502 commit eedc38a

File tree

5 files changed

+3709
-372
lines changed

5 files changed

+3709
-372
lines changed

extension/llm/custom_ops/spinquant/third-party/FFHT/fht.h

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,44 @@
11
#ifndef _FHT_H_
22
#define _FHT_H_
3-
#include <string.h>
43
#include <stdlib.h>
4+
#include <string.h>
55

66
#ifdef __cplusplus
77
extern "C" {
88
#endif
99

10-
int fht_float(float *buf, int log_n);
11-
int fht_double(double *buf, int log_n);
12-
int fht_float_oop(float *in, float *out, int log_n);
13-
int fht_double_oop(double *in, double *out, int log_n);
14-
10+
int fht_float(float* buf, int log_n);
11+
#ifndef __aarch64__
12+
int fht_double(double* buf, int log_n);
13+
#endif
14+
int fht_float_oop(float* in, float* out, int log_n);
15+
#ifndef __aarch64__
16+
int fht_double_oop(double* in, double* out, int log_n);
17+
#endif
1518

1619
#ifdef __cplusplus
1720

1821
} // extern "C"
1922

20-
static inline int fht(float *buf, int log_n) {
21-
return fht_float(buf, log_n);
23+
static inline int fht(float* buf, int log_n) {
24+
return fht_float(buf, log_n);
2225
}
2326

24-
static inline int fht(double *buf, int log_n) {
25-
return fht_double(buf, log_n);
27+
#ifndef __aarch64__
28+
static inline int fht(double* buf, int log_n) {
29+
return fht_double(buf, log_n);
2630
}
31+
#endif
2732

28-
static inline int fht(float *buf, float *out, int log_n) {
29-
return fht_float_oop(buf, out, log_n);
33+
static inline int fht(float* buf, float* out, int log_n) {
34+
return fht_float_oop(buf, out, log_n);
3035
}
3136

32-
static inline int fht(double *buf, double *out, int log_n) {
33-
return fht_double_oop(buf, out, log_n);
37+
#ifndef __aarch64__
38+
static inline int fht(double* buf, double* out, int log_n) {
39+
return fht_double_oop(buf, out, log_n);
3440
}
41+
#endif
3542

3643
#endif
3744

extension/llm/custom_ops/spinquant/third-party/FFHT/fht_impl.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,30 @@
77
extern "C" {
88
#endif
99

10+
#ifdef __aarch64__
11+
#include "fht_neon.c"
12+
#define VECTOR_WIDTH (16u)
13+
#else
1014
#ifdef __AVX__
1115
#include "fht_avx.c"
1216
#define VECTOR_WIDTH (32u)
1317
#else
1418
#include "fht_sse.c"
1519
#define VECTOR_WIDTH (16u)
1620
#endif
21+
#endif
1722

18-
int fht_float_oop(float *in, float *out, int log_n) {
19-
fast_copy(out, in, sizeof(float) << log_n);
20-
return fht_float(out, log_n);
23+
int fht_float_oop(float* in, float* out, int log_n) {
24+
fast_copy(out, in, sizeof(float) << log_n);
25+
return fht_float(out, log_n);
2126
}
2227

23-
int fht_double_oop(double *in, double *out, int log_n) {
24-
fast_copy(out, in, sizeof(double) << log_n);
25-
return fht_double(out, log_n);
28+
#ifndef __aarch64__
29+
int fht_double_oop(double* in, double* out, int log_n) {
30+
fast_copy(out, in, sizeof(double) << log_n);
31+
return fht_double(out, log_n);
2632
}
33+
#endif
2734

2835
#ifdef __cplusplus
2936
} // extern "C"

0 commit comments

Comments
 (0)