#ifdef __aarch64__
#include <arm_neon.h>
+ #include <cpuinfo.h>
#endif

using torch::executor::BFloat16;
@@ -23,7 +24,7 @@ static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
- #endif
+ #endif // __ARM_FEATURE_FMA
}

// The below reduce overload and fp16_dot_with_fp32_arith are adapted
@@ -78,35 +79,39 @@ static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}
- #endif
+ #endif // __ARM_FEATURE_BF16

+ template <bool useBfloat16Dot>
static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
#ifdef __ARM_FEATURE_BF16
-   const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-       &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-   const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-       &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-   sum[registerPairIndex] =
-       f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
- #else
-   const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-       &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-   const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-       &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-
-   sum[2 * registerPairIndex] = f32_fma_bf16(
-       sum[2 * registerPairIndex],
-       vget_low_u16(temp_vec1),
-       vget_low_u16(temp_vec2));
-   sum[2 * registerPairIndex + 1] = f32_fma_bf16(
-       sum[2 * registerPairIndex + 1],
-       vget_high_u16(temp_vec1),
-       vget_high_u16(temp_vec2));
- #endif
+   if (useBfloat16Dot) {
+     const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+         &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+     const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+         &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+     sum[registerPairIndex] =
+         f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+   } else
+ #endif // __ARM_FEATURE_BF16
+   {
+     const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+         &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+     const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+         &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+     sum[2 * registerPairIndex] = f32_fma_bf16(
+         sum[2 * registerPairIndex],
+         vget_low_u16(temp_vec1),
+         vget_low_u16(temp_vec2));
+     sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+         sum[2 * registerPairIndex + 1],
+         vget_high_u16(temp_vec1),
+         vget_high_u16(temp_vec2));
+   }
}

static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
@@ -121,7 +126,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

- template <typename T>
+ template <typename T, bool useBfloat16Dot>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
@@ -130,7 +135,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
    const auto* vec2_ = vec2 + j;
    utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
        [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
-         dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
+         dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
+             vec1_, vec2_, sum, k);
        });
  }
  auto reducedSum = reduce(sum);
@@ -157,9 +163,16 @@ float bf16_dot_with_fp32_arith(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int64_t len) {
-   return dot_with_fp32_arith(vec1, vec2, len);
+ #ifdef __ARM_FEATURE_BF16
+   if (cpuinfo_has_arm_bf16()) {
+     return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
+   } else
+ #endif // __ARM_FEATURE_BF16
+   {
+     return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
+   }
}
- #endif
+ #endif // __aarch64__
} // namespace internal
} // namespace cpublas
} // namespace executorch
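
Note on the dispatch pattern this diff introduces: the "} else" placed immediately before "#endif" lets one body serve two build configurations. When the toolchain defines __ARM_FEATURE_BF16, the branches form an ordinary if/else on the runtime cpuinfo_has_arm_bf16() check; when it does not, the if/else vanishes in preprocessing and only the brace-delimited fallback block survives, so it runs unconditionally. A minimal standalone sketch of the trick (the kernel names here are hypothetical stand-ins, and cpuinfo expects cpuinfo_initialize() to have been called before any query):

#include <cpuinfo.h>

static float fast_path() { return 1.0f; } // stands in for the bf16 kernel
static float fallback()  { return 0.0f; } // stands in for the fp32 kernel

float dispatch() {
#ifdef __ARM_FEATURE_BF16
  if (cpuinfo_has_arm_bf16()) {
    return fast_path(); // compiled only when the toolchain supports BF16
  } else
#endif // __ARM_FEATURE_BF16
  {
    return fallback(); // always compiled; runs unconditionally when the
                       // macro is absent, since the else vanished with it
  }
}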
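
For context on the fallback branch: f32_fma_bf16 is defined outside these hunks, but the uint16x8_t loads and the vget_low_u16/vget_high_u16 splits imply that it widens each bfloat16 lane to float32 before a fused multiply-add. A plausible sketch, assuming the standard widening trick (bfloat16 is the upper 16 bits of an IEEE-754 float32, so widening is a zero-extend plus a 16-bit left shift); this is an illustration, not the committed implementation:

#include <arm_neon.h>

// Widen four bf16 values (held as uint16 bit patterns) to float32.
static inline float32x4_t bf16_to_f32(uint16x4_t u16) {
  return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16), 16));
}

// Computes a + b * c, with b and c given as bf16 half-vectors.
static inline float32x4_t sketch_f32_fma_bf16(
    float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return vfmaq_f32(a, bf16_to_f32(b), bf16_to_f32(c));
}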