@@ -33,7 +33,8 @@ namespace executorch::cpublas::internal {
 constexpr auto kF32RegisterPairsPerIteration = 4;
 constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
 constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
-constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
+constexpr auto kF32ElementsPerIteration =
+    kF32RegistersPerIteration * kF32ElementsPerRegister;
 
 namespace {
 template <typename T>
@@ -58,8 +59,8 @@ constexpr int IntegerLog2(T n, int p = 0) {
  * copies of the Software, and to permit persons to whom the Software is
  * furnished to do so, subject to the following conditions:
  *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
  *
  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -74,9 +75,7 @@ float reduce(vec::Vectorized<float> x) {
 #if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
   return vaddvq_f32(x);
 #else
-  return vec::vec_reduce_all<float>(
-      std::plus<vec::Vectorized<float>>(),
-      x);
+  return vec::vec_reduce_all<float>(std::plus<vec::Vectorized<float>>(), x);
 #endif
 }
 
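
The #if branch above exists because vaddvq_f32 is an A64-only NEON intrinsic that sums all four fp32 lanes of a register into a scalar; the vec_reduce_all fallback expresses the same horizontal reduction portably. A minimal standalone sketch of the fast path (hypothetical function name):

#if defined(__aarch64__)
#include <arm_neon.h>
// vaddvq_f32 performs a horizontal add across lanes: lane0 + lane1 + lane2 + lane3.
float horizontal_sum_sketch(float32x4_t v) {
  return vaddvq_f32(v);
}
#endif
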
@@ -86,12 +85,13 @@ float reduce(vec::Vectorized<float> x) {
 // required notice.
 float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
   int offset = kF32RegistersPerIteration;
-  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}([&offset, &x](auto idx) {
-    offset /= 2;
-    for (const auto i : c10::irange(offset)) {
-      x[i] = x[i] + x[offset + i];
-    }
-  });
+  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}(
+      [&offset, &x](auto idx) {
+        offset /= 2;
+        for (const auto i : c10::irange(offset)) {
+          x[i] = x[i] + x[offset + i];
+        }
+      });
   return reduce(x[0]);
 }
 
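
This reduce() folds kF32RegistersPerIteration accumulators down to one in IntegerLog2(kF32RegistersPerIteration) unrolled steps, halving the live count each pass. The same tree reduction over a plain array, as a standalone sketch (hypothetical name, width fixed at 8):

#include <array>

float tree_reduce_sketch(std::array<float, 8>& x) {
  // Each pass adds the top half into the bottom half: 8 -> 4 -> 2 -> 1.
  for (int offset = 8 / 2; offset > 0; offset /= 2) {
    for (int i = 0; i < offset; ++i) {
      x[i] += x[offset + i];
    }
  }
  return x[0];
}

The tree shape lets the adds within each pass issue independently, unlike a serial accumulation where every add depends on the previous one.
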
@@ -102,16 +102,20 @@ float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
 // We would have to write a separate SVE-specific path to use SVE
 // BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path
 // working.
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    defined(__clang__) && __clang_major__ > 15
 // https://godbolt.org/z/z8P4Yncra
 #define COMPILER_SUPPORTS_BF16_TARGET 1
-#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
+#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
 // https://gcc.gnu.org/gcc-10/changes.html
 // https://godbolt.org/z/cdGG7vn8o
 #define COMPILER_SUPPORTS_BF16_TARGET 1
-#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+      // defined(__clang__) && __clang_major__ > 15
 #define COMPILER_SUPPORTS_BF16_TARGET 0
-#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+       // defined(__clang__) && __clang_major__ > 15
 
 #if COMPILER_SUPPORTS_BF16_TARGET
 #define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
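
TARGET_ARM_BF16_ATTRIBUTE is what lets a translation unit built for baseline armv8-a still contain BFDOT code: only the annotated functions get armv8.2-a+bf16 codegen, so they must be reached behind a runtime check (cpuinfo_has_arm_bf16() later in this file). A minimal sketch of the pattern under the same compiler-support assumption, with a hypothetical function name:

#if COMPILER_SUPPORTS_BF16_TARGET
#include <arm_neon.h>

// Compiled with bf16 codegen even though the rest of the TU is baseline.
__attribute__((target("arch=armv8.2-a+bf16")))
float bfdot_lane_sum_sketch(bfloat16x8_t a, bfloat16x8_t b) {
  float32x4_t acc = vdupq_n_f32(0.0f);
  acc = vbfdotq_f32(acc, a, b); // pairwise bf16 products, fp32 accumulation
  return vaddvq_f32(acc);
}
#endif // COMPILER_SUPPORTS_BF16_TARGET
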
@@ -128,25 +132,25 @@ dot_with_fp32_arith_main_inner_loop_bfdot(
   // bfloat16x8_t. I suspect a bug or incomplete
   // __attribute__((target)) implementation. Intrinsics should be fine
   // because we're using vbfdotq_f32 below anyway.
-  const auto temp_vec1 = vld1q_bf16(
-      reinterpret_cast<const bfloat16_t*>(
-          &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
-  const auto temp_vec2 = vld1q_bf16(
-      reinterpret_cast<const bfloat16_t*>(
-          &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
+  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
+  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
   sum[registerPairIndex] =
-    vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
+      vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
 }
 
-TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
-void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
+TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void
+dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
     const at::BFloat16* vec1,
     const at::BFloat16* vec2,
     vec::Vectorized<float>* tail_sum,
     int idx) {
   // See NOTE[Intrinsics in bfdot variant] above.
-  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
-  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
+  const auto temp_vec1 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
+  const auto temp_vec2 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
   *tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2);
 }
 
@@ -156,14 +160,17 @@ void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
 
 namespace {
 
-[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
+[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>>
+fmadd(
     const vec::Vectorized<c10::BFloat16>& a,
     const vec::Vectorized<c10::BFloat16>& b,
    const vec::Vectorized<float>& acc_low,
    const vec::Vectorized<float>& acc_high) {
   const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
   const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
-  return std::make_pair(fmadd(a_float_low, b_float_low, acc_low), fmadd(a_float_high, b_float_high, acc_high));
+  return std::make_pair(
+      fmadd(a_float_low, b_float_low, acc_low),
+      fmadd(a_float_high, b_float_high, acc_high));
 }
 
 [[maybe_unused]] vec::Vectorized<float> fmadd(
@@ -172,21 +179,28 @@ namespace {
     const vec::Vectorized<c10::BFloat16>& b) {
   const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
   const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
-  return fmadd(a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
+  return fmadd(
+      a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
 }
 } // namespace
 
 template <typename T>
 C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
-  const T* vec1,
-  const T* vec2,
-  vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
-  int registerPairIndex) {
+    const T* vec1,
+    const T* vec2,
+    vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
+    int registerPairIndex) {
   static_assert(std::is_same_v<T, BFloat16>);
-  const auto temp_vec1 = vec::Vectorized<T>::loadu(&vec1[registerPairIndex * vec::Vectorized<T>::size()]);
-  const auto temp_vec2 = vec::Vectorized<T>::loadu(&vec2[registerPairIndex * vec::Vectorized<T>::size()]);
-
-  const auto [result_low, result_high] = fmadd(temp_vec1, temp_vec2, sum[2 * registerPairIndex], sum[2 * registerPairIndex + 1]);
+  const auto temp_vec1 = vec::Vectorized<T>::loadu(
+      &vec1[registerPairIndex * vec::Vectorized<T>::size()]);
+  const auto temp_vec2 = vec::Vectorized<T>::loadu(
+      &vec2[registerPairIndex * vec::Vectorized<T>::size()]);
+
+  const auto [result_low, result_high] = fmadd(
+      temp_vec1,
+      temp_vec2,
+      sum[2 * registerPairIndex],
+      sum[2 * registerPairIndex + 1]);
   sum[2 * registerPairIndex] = result_low;
   sum[2 * registerPairIndex + 1] = result_high;
 }
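
The fmadd overloads and this inner loop share one contract: widen each bf16 element to fp32, multiply, and accumulate in fp32. The scalar picture of that widening, as a sketch (hypothetical helper; it relies only on bf16 being the upper 16 bits of an fp32 bit pattern):

#include <cstring>

float fmadd_bf16_scalar_sketch(unsigned short a_bits, unsigned short b_bits, float acc) {
  auto widen = [](unsigned short bits) {
    unsigned int shifted = static_cast<unsigned int>(bits) << 16;
    float f;
    std::memcpy(&f, &shifted, sizeof(f)); // reinterpret the bits as fp32
    return f;
  };
  return widen(a_bits) * widen(b_bits) + acc;
}

Accumulating in fp32 rather than bf16 is the point of the _with_fp32_arith naming: bf16 keeps only about 8 significand bits, so accumulating a long dot product in bf16 would discard most of the sum.
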
@@ -203,19 +217,19 @@ C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
 }
 
 template <typename T>
-C10_ALWAYS_INLINE auto
-dot_with_fp32_arith_main_loop_no_bfdot(
+C10_ALWAYS_INLINE auto dot_with_fp32_arith_main_loop_no_bfdot(
     const T* vec1,
     const T* vec2,
     int64_t len) {
   vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
-  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
-      dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
-    });
+    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
+          dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
+        });
   }
   return reduce(sum);
 }
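
The len & ~(kF32ElementsPerIteration - 1) expression rounds len down to a multiple of the iteration width; it is only correct because the width is a power of two (enforced by a static_assert further down). A compile-time sketch of the trick:

// k - 1 is a mask of the low bits only when k is a power of two.
constexpr long round_down_pow2(long len, long k) {
  return len & ~(k - 1);
}
static_assert(round_down_pow2(37, 8) == 32);
static_assert(round_down_pow2(32, 8) == 32);
static_assert(round_down_pow2(7, 8) == 0);
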
@@ -224,7 +238,8 @@ dot_with_fp32_arith_main_loop_no_bfdot(
 template <int n>
 struct ForcedUnrollTargetBFloat16 {
   template <typename Func>
-  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     ForcedUnrollTargetBFloat16<n - 1>{}(f);
     f(n - 1);
   }
@@ -233,7 +248,8 @@ struct ForcedUnrollTargetBFloat16 {
 template <>
 struct ForcedUnrollTargetBFloat16<1> {
   template <typename Func>
-  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     f(0);
   }
 };
@@ -245,20 +261,22 @@ dot_with_fp32_arith_main_loop_bfdot(
     int64_t len) {
   vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
-  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k)
-                                                                C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
-      dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
-    });
+    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k)
+            C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
+              dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
+            });
   }
   return reduce(sum);
 }
 #endif // COMPILER_SUPPORTS_BF16_TARGET
 
 static_assert(
-    (vec::Vectorized<BFloat16>::size() & (vec::Vectorized<BFloat16>::size() - 1)) == 0,
+    (vec::Vectorized<BFloat16>::size() &
+     (vec::Vectorized<BFloat16>::size() - 1)) == 0,
     "Below code expects power-of-2 vector register size!");
 
 // NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
@@ -267,31 +285,35 @@ static_assert(
 // function. We can work around this by duplicating the code into the
 // bfdot and non-bfdot callsites. The code is in this macro to avoid
 // actual copy/paste.
-#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
-  /* First-tier tail fixup: make sure we handle workloads that can */ \
-  /* benefit from vectorization, but don't fit into our fully unrolled */ \
-  /* loop above. */ \
-  vec::Vectorized<float> tail_sum(0); \
-  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
+#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix)           \
+  /* First-tier tail fixup: make sure we handle workloads that can */         \
+  /* benefit from vectorization, but don't fit into our fully unrolled */     \
+  /* loop above. */                                                           \
+  vec::Vectorized<float> tail_sum(0);                                         \
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);             \
   const auto len_aligned_vec = len & ~(vec::Vectorized<BFloat16>::size() - 1); \
-  for (int j = len_aligned; j < len_aligned_vec; j += vec::Vectorized<BFloat16>::size()) { \
-    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \
-  } \
-  reduced_sum += reduce(tail_sum); \
-  \
-  /* Second-tier tail fixup: handle all workloads. */ \
-  for (const auto j : c10::irange(len_aligned_vec, len)) { \
-    /* Attempting to use Half here caused multiple test failures; */ \
-    /* using float to unbreak. (Suspect we need a scalar FMA.) */ \
-    float x1 = vec1[j]; \
-    float x2 = vec2[j]; \
-    reduced_sum += x1 * x2; \
-  } \
+  for (int j = len_aligned; j < len_aligned_vec;                              \
+       j += vec::Vectorized<BFloat16>::size()) {                              \
+    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(             \
+        vec1, vec2, &tail_sum, j);                                            \
+  }                                                                           \
+  reduced_sum += reduce(tail_sum);                                            \
+                                                                              \
+  /* Second-tier tail fixup: handle all workloads. */                         \
+  for (const auto j : c10::irange(len_aligned_vec, len)) {                    \
+    /* Attempting to use Half here caused multiple test failures; */          \
+    /* using float to unbreak. (Suspect we need a scalar FMA.) */             \
+    float x1 = vec1[j];                                                       \
+    float x2 = vec2[j];                                                       \
+    reduced_sum += x1 * x2;                                                   \
+  }                                                                           \
   return reduced_sum
 
 #if COMPILER_SUPPORTS_BF16_TARGET
-TARGET_ARM_BF16_ATTRIBUTE float
-dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
+TARGET_ARM_BF16_ATTRIBUTE float dot_with_fp32_arith_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    int64_t len) {
   auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
   DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
 }
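
Putting the macro together with the main loops gives a three-phase structure: a fully unrolled SIMD main loop, a single-register SIMD tail, then a scalar tail. The same shape in plain fp32, as a standalone sketch (hypothetical widths standing in for kF32ElementsPerIteration and the register size):

float dot_three_phase_sketch(const float* a, const float* b, long len) {
  const long kBig = 32;  // stand-in for kF32ElementsPerIteration
  const long kSmall = 8; // stand-in for vec::Vectorized<BFloat16>::size()
  float sum = 0.0f;
  long j = 0;
  for (; j + kBig <= len; j += kBig)     // phase 1: unrolled main loop
    for (long i = 0; i < kBig; ++i)
      sum += a[j + i] * b[j + i];
  for (; j + kSmall <= len; j += kSmall) // phase 2: one-register tail
    for (long i = 0; i < kSmall; ++i)
      sum += a[j + i] * b[j + i];
  for (; j < len; ++j)                   // phase 3: scalar tail
    sum += a[j] * b[j];
  return sum;
}
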
@@ -307,7 +329,10 @@ dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
 
 } // namespace
 
-float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
+float bf16_dot_with_fp32_arith(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    int64_t len) {
 #if COMPILER_SUPPORTS_BF16_TARGET
   if (cpuinfo_has_arm_bf16()) {
     return dot_with_fp32_arith_bfdot(vec1, vec2, len);