Commit 294798b

lintrunner on "Mostly sync BlasKernel.cpp with ATen ReducedPrecisionGemvFastPathKernel"
The two files were similar, but had diverged due to recent changes. Since we share PyTorch headers, we can keep them mostly the same; the remaining differences are some namespace details and a couple of EXECUTORCH NOTEs.

Differential Revision: [D74702689](https://our.internmc.facebook.com/intern/diff/D74702689/)

[ghstack-poisoned]
1 parent 4a91e45 commit 294798b

File tree: 1 file changed (+99, −74)

kernels/optimized/blas/BlasKernel.cpp

Lines changed: 99 additions & 74 deletions
@@ -33,7 +33,8 @@ namespace executorch::cpublas::internal {
 constexpr auto kF32RegisterPairsPerIteration = 4;
 constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
 constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
-constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
+constexpr auto kF32ElementsPerIteration =
+    kF32RegistersPerIteration * kF32ElementsPerRegister;
 
 namespace {
 template <typename T>
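
A note on these constants: with 128-bit NEON registers, `vec::Vectorized<float>::size()` is 4, so each fully unrolled iteration covers 8 registers × 4 floats = 32 elements. A minimal worked check of that arithmetic (a standalone sketch assuming the NEON register width; the short names are illustrative, not from the file):

```cpp
// Worked check of the constants above, assuming 128-bit NEON registers
// (vec::Vectorized<float>::size() == 4). Standalone illustration only.
constexpr int kRegisterPairsPerIteration = 4;
constexpr int kRegistersPerIteration = kRegisterPairsPerIteration * 2; // 8
constexpr int kElementsPerRegister = 4; // floats per 128-bit register
constexpr int kElementsPerIteration =
    kRegistersPerIteration * kElementsPerRegister;
static_assert(kElementsPerIteration == 32, "8 registers x 4 floats");
```
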
@@ -58,8 +59,8 @@ constexpr int IntegerLog2(T n, int p = 0) {
  * copies of the Software, and to permit persons to whom the Software is
  * furnished to do so, subject to the following conditions:
  *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
  *
  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -74,9 +75,7 @@ float reduce(vec::Vectorized<float> x) {
 #if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
   return vaddvq_f32(x);
 #else
-  return vec::vec_reduce_all<float>(
-      std::plus<vec::Vectorized<float>>(),
-      x);
+  return vec::vec_reduce_all<float>(std::plus<vec::Vectorized<float>>(), x);
 #endif
 }
 
@@ -86,12 +85,13 @@ float reduce(vec::Vectorized<float> x) {
 // required notice.
 float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
   int offset = kF32RegistersPerIteration;
-  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}([&offset, &x](auto idx) {
-    offset /= 2;
-    for (const auto i : c10::irange(offset)) {
-      x[i] = x[i] + x[offset + i];
-    }
-  });
+  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}(
+      [&offset, &x](auto idx) {
+        offset /= 2;
+        for (const auto i : c10::irange(offset)) {
+          x[i] = x[i] + x[offset + i];
+        }
+      });
   return reduce(x[0]);
 }
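
The hunk above only re-wraps the pairwise tree reduction, but its shape is worth spelling out: each unrolled step folds the upper half of the register array into the lower half, so 8 registers reduce in log2(8) = 3 steps. A scalar sketch of the same idea (illustrative, not the file's code):

```cpp
#include <array>

// Scalar version of the halving reduction above: fold the upper half
// into the lower half until one element remains.
float tree_reduce(std::array<float, 8> x) {
  for (int offset = static_cast<int>(x.size()) / 2; offset >= 1; offset /= 2) {
    for (int i = 0; i < offset; ++i) {
      x[i] = x[i] + x[offset + i]; // same combine step as the lambda body
    }
  }
  return x[0];
}
```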

@@ -102,16 +102,20 @@ float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
 // We would have to write a separate SVE-specific path to use SVE
 // BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path
 // working.
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    defined(__clang__) && __clang_major__ > 15
 // https://godbolt.org/z/z8P4Yncra
 #define COMPILER_SUPPORTS_BF16_TARGET 1
-#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
+#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
 // https://gcc.gnu.org/gcc-10/changes.html
 // https://godbolt.org/z/cdGG7vn8o
 #define COMPILER_SUPPORTS_BF16_TARGET 1
-#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+      // defined(__clang__) && __clang_major__ > 15
 #define COMPILER_SUPPORTS_BF16_TARGET 0
-#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+       // defined(__clang__) && __clang_major__ > 15
 
 #if COMPILER_SUPPORTS_BF16_TARGET
 #define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
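
Context for the gate being re-wrapped here: `__attribute__((target(...)))` compiles just the annotated functions for `armv8.2-a+bf16` while the rest of the translation unit keeps the baseline ISA, which is why the compiler-version checks matter. A minimal sketch of the pattern (hypothetical function name, not from this file):

```cpp
// Sketch: per-function target attribute. Only bf16_kernel() is compiled
// with BF16 enabled; callers must check CPU support at runtime before
// calling it on a machine whose baseline ISA lacks BF16.
#if defined(__aarch64__)
__attribute__((target("arch=armv8.2-a+bf16"))) void bf16_kernel() {
  // BF16 instructions/intrinsics are legal to emit in this body.
}
#endif
```
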
@@ -128,25 +132,25 @@ dot_with_fp32_arith_main_inner_loop_bfdot(
   // bfloat16x8_t. I suspect a bug or incomplete
   // __attribute__((target)) implementation. Intrinsics should be fine
   // because we're using vbfdotq_f32 below anyway.
-  const auto temp_vec1 = vld1q_bf16(
-      reinterpret_cast<const bfloat16_t*>(
-          &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
-  const auto temp_vec2 = vld1q_bf16(
-      reinterpret_cast<const bfloat16_t*>(
-          &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
+  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
+  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
   sum[registerPairIndex] =
-    vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
+      vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
 }
 
-TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
-void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
+TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void
+dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
     const at::BFloat16* vec1,
     const at::BFloat16* vec2,
     vec::Vectorized<float>* tail_sum,
     int idx) {
   // See NOTE[Intrinsics in bfdot variant] above.
-  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
-  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
+  const auto temp_vec1 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
+  const auto temp_vec2 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
   *tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2);
 }
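
For readers new to the intrinsics in this hunk: `vld1q_bf16` loads eight bf16 values, and `vbfdotq_f32` computes dot products of adjacent bf16 lane pairs into four f32 accumulator lanes. A self-contained sketch of a BFDOT dot product built from just those intrinsics (assumes a compiler with BF16 support and `len` a multiple of 8; illustrative only, not the file's code):

```cpp
#if defined(__aarch64__)
#include <arm_neon.h>

// Minimal BFDOT dot product sketch. vbfdotq_f32 multiplies adjacent
// bf16 lane pairs and accumulates into 4 f32 lanes; vaddvq_f32 then
// sums those lanes. Assumes len % 8 == 0.
__attribute__((target("arch=armv8.2-a+bf16"))) float bf16_dot(
    const bfloat16_t* a, const bfloat16_t* b, int len) {
  float32x4_t acc = vdupq_n_f32(0.0f);
  for (int i = 0; i < len; i += 8) {
    acc = vbfdotq_f32(acc, vld1q_bf16(a + i), vld1q_bf16(b + i));
  }
  return vaddvq_f32(acc);
}
#endif
```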

@@ -156,14 +160,17 @@ void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
 
 namespace {
 
-[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
+[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>>
+fmadd(
     const vec::Vectorized<c10::BFloat16>& a,
     const vec::Vectorized<c10::BFloat16>& b,
     const vec::Vectorized<float>& acc_low,
     const vec::Vectorized<float>& acc_high) {
   const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
   const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
-  return std::make_pair(fmadd(a_float_low, b_float_low, acc_low), fmadd(a_float_high, b_float_high, acc_high));
+  return std::make_pair(
+      fmadd(a_float_low, b_float_low, acc_low),
+      fmadd(a_float_high, b_float_high, acc_high));
 }
 
 [[maybe_unused]] vec::Vectorized<float> fmadd(
@@ -172,21 +179,28 @@ namespace {
     const vec::Vectorized<c10::BFloat16>& b) {
   const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
   const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
-  return fmadd(a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
+  return fmadd(
+      a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
 }
 } // namespace
 
 template <typename T>
 C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
-  const T* vec1,
-  const T* vec2,
-  vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
-  int registerPairIndex) {
+    const T* vec1,
+    const T* vec2,
+    vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
+    int registerPairIndex) {
   static_assert(std::is_same_v<T, BFloat16>);
-  const auto temp_vec1 = vec::Vectorized<T>::loadu(&vec1[registerPairIndex * vec::Vectorized<T>::size()]);
-  const auto temp_vec2 = vec::Vectorized<T>::loadu(&vec2[registerPairIndex * vec::Vectorized<T>::size()]);
-
-  const auto [result_low, result_high] = fmadd(temp_vec1, temp_vec2, sum[2 * registerPairIndex], sum[2 * registerPairIndex + 1]);
+  const auto temp_vec1 = vec::Vectorized<T>::loadu(
+      &vec1[registerPairIndex * vec::Vectorized<T>::size()]);
+  const auto temp_vec2 = vec::Vectorized<T>::loadu(
+      &vec2[registerPairIndex * vec::Vectorized<T>::size()]);
+
+  const auto [result_low, result_high] = fmadd(
+      temp_vec1,
+      temp_vec2,
+      sum[2 * registerPairIndex],
+      sum[2 * registerPairIndex + 1]);
   sum[2 * registerPairIndex] = result_low;
   sum[2 * registerPairIndex + 1] = result_high;
 }
@@ -203,19 +217,19 @@ C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
 }
 
 template <typename T>
-C10_ALWAYS_INLINE auto
-dot_with_fp32_arith_main_loop_no_bfdot(
+C10_ALWAYS_INLINE auto dot_with_fp32_arith_main_loop_no_bfdot(
     const T* vec1,
     const T* vec2,
     int64_t len) {
   vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
-  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
-      dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
-    });
+    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
+          dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
+        });
   }
   return reduce(sum);
 }
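
One detail in this loop deserves a gloss: because `kF32ElementsPerIteration` is a power of two, `len & ~(kF32ElementsPerIteration - 1)` rounds `len` down to a multiple of it without a division. A standalone check (illustrative name):

```cpp
#include <cstdint>

// Power-of-two round-down used for the loop bound above.
constexpr int64_t round_down_pow2(int64_t len, int64_t k) {
  return len & ~(k - 1); // valid only when k is a power of two
}
static_assert(round_down_pow2(37, 32) == 32, "37 rounds down to 32");
static_assert(round_down_pow2(64, 32) == 64, "multiples are unchanged");
```
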
@@ -224,7 +238,8 @@ dot_with_fp32_arith_main_loop_no_bfdot(
 template <int n>
 struct ForcedUnrollTargetBFloat16 {
   template <typename Func>
-  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     ForcedUnrollTargetBFloat16<n - 1>{}(f);
     f(n - 1);
   }
@@ -233,7 +248,8 @@ struct ForcedUnrollTargetBFloat16 {
 template <>
 struct ForcedUnrollTargetBFloat16<1> {
   template <typename Func>
-  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     f(0);
   }
 };
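
The template pair above is the forced-unroll idiom: recursion at the type level expands `operator()` into `f(0), f(1), ..., f(n - 1)` with no runtime loop, and carrying `TARGET_ARM_BF16_ATTRIBUTE` on every instantiation evidently lets the BF16-target lambda inline through the recursion (which plain `c10::ForcedUnroll` would not allow). A stripped-down sketch without the attribute (illustrative names):

```cpp
// Forced unroll via template recursion: the calls are expanded at
// compile time, so there is no loop counter at runtime.
template <int n>
struct ForcedUnrollSketch {
  template <typename Func>
  void operator()(const Func& f) const {
    ForcedUnrollSketch<n - 1>{}(f); // expand f(0) .. f(n - 2)
    f(n - 1);
  }
};

template <>
struct ForcedUnrollSketch<1> {
  template <typename Func>
  void operator()(const Func& f) const {
    f(0); // base case
  }
};

// Usage: total becomes 0 + 1 + 2 + 3 = 6 via four inlined calls.
int sum_0_to_3() {
  int total = 0;
  ForcedUnrollSketch<4>{}([&total](int k) { total += k; });
  return total;
}
```
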
@@ -245,20 +261,22 @@ dot_with_fp32_arith_main_loop_bfdot(
     int64_t len) {
   vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
-  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k)
-                                                                C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
-      dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
-    });
+    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k)
+            C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
+              dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
+            });
   }
   return reduce(sum);
 }
 #endif // COMPILER_SUPPORTS_BF16_TARGET
 
 static_assert(
-    (vec::Vectorized<BFloat16>::size() & (vec::Vectorized<BFloat16>::size() - 1)) == 0,
+    (vec::Vectorized<BFloat16>::size() &
+     (vec::Vectorized<BFloat16>::size() - 1)) == 0,
     "Below code expects power-of-2 vector register size!");
 
 // NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
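
The `static_assert` reformatted in the hunk above uses the classic single-set-bit test: for positive `n`, `n` is a power of two iff `n & (n - 1) == 0`. An equivalent standalone form (illustrative name):

```cpp
// Power-of-two test used by the static_assert above.
constexpr bool is_pow2(int n) {
  return n > 0 && (n & (n - 1)) == 0;
}
static_assert(is_pow2(8) && !is_pow2(12), "single-set-bit test");
```
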
@@ -267,31 +285,35 @@ static_assert(
 // function. We can work around this by duplicating the code into the
 // bfdot and non-bfdot callsites. The code is in this macro to avoid
 // actual copy/paste.
-#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
-  /* First-tier tail fixup: make sure we handle workloads that can */ \
-  /* benefit from vectorization, but don't fit into our fully unrolled */ \
-  /* loop above. */ \
-  vec::Vectorized<float> tail_sum(0); \
-  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
+#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix)          \
+  /* First-tier tail fixup: make sure we handle workloads that can */        \
+  /* benefit from vectorization, but don't fit into our fully unrolled */    \
+  /* loop above. */                                                          \
+  vec::Vectorized<float> tail_sum(0);                                        \
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);            \
   const auto len_aligned_vec = len & ~(vec::Vectorized<BFloat16>::size() - 1); \
-  for (int j = len_aligned; j < len_aligned_vec; j += vec::Vectorized<BFloat16>::size()) { \
-    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \
-  } \
-  reduced_sum += reduce(tail_sum); \
-  \
-  /* Second-tier tail fixup: handle all workloads. */ \
-  for (const auto j : c10::irange(len_aligned_vec, len)) { \
-    /* Attempting to use Half here caused multiple test failures; */ \
-    /* using float to unbreak. (Suspect we need a scalar FMA.) */ \
-    float x1 = vec1[j]; \
-    float x2 = vec2[j]; \
-    reduced_sum += x1 * x2; \
-  } \
+  for (int j = len_aligned; j < len_aligned_vec;                             \
+       j += vec::Vectorized<BFloat16>::size()) {                             \
+    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(            \
+        vec1, vec2, &tail_sum, j);                                           \
+  }                                                                          \
+  reduced_sum += reduce(tail_sum);                                           \
+                                                                             \
+  /* Second-tier tail fixup: handle all workloads. */                        \
+  for (const auto j : c10::irange(len_aligned_vec, len)) {                   \
+    /* Attempting to use Half here caused multiple test failures; */         \
+    /* using float to unbreak. (Suspect we need a scalar FMA.) */            \
+    float x1 = vec1[j];                                                      \
+    float x2 = vec2[j];                                                      \
+    reduced_sum += x1 * x2;                                                  \
+  }                                                                          \
   return reduced_sum
 
 #if COMPILER_SUPPORTS_BF16_TARGET
-TARGET_ARM_BF16_ATTRIBUTE float
-dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
+TARGET_ARM_BF16_ATTRIBUTE float dot_with_fp32_arith_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    int64_t len) {
   auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
   DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
 }
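
Stripping away the macro plumbing (which, per the NOTE above, exists only to dodge a GCC issue without copy/paste), the control flow is a three-stage dot product: a fully unrolled main loop, a vectorized tail over remaining whole registers, then a scalar tail. A plain-C++ sketch of that shape (scalar stand-ins for the vector helpers; the block sizes are illustrative assumptions):

```cpp
#include <cstdint>

// Scalar stand-in for a vectorized partial dot product of n elements.
static float partial_dot(const float* a, const float* b, int64_t n) {
  float s = 0;
  for (int64_t i = 0; i < n; ++i) {
    s += a[i] * b[i];
  }
  return s;
}

// The three-stage shape expressed by the main loop plus the macro above.
float dot_three_stage(const float* v1, const float* v2, int64_t len) {
  constexpr int64_t kBlock = 32; // elements per fully unrolled iteration
  constexpr int64_t kReg = 8;    // elements per vector register
  const int64_t len_block = len & ~(kBlock - 1);
  const int64_t len_reg = len & ~(kReg - 1);
  float sum = 0;
  int64_t j = 0;
  for (; j < len_block; j += kBlock) { // 1) unrolled main loop
    sum += partial_dot(v1 + j, v2 + j, kBlock);
  }
  for (; j < len_reg; j += kReg) { // 2) first-tier (vectorized) tail
    sum += partial_dot(v1 + j, v2 + j, kReg);
  }
  for (; j < len; ++j) { // 3) second-tier (scalar) tail
    sum += v1[j] * v2[j];
  }
  return sum;
}
```
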
@@ -307,7 +329,10 @@ dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
 
 } // namespace
 
-float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
+float bf16_dot_with_fp32_arith(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    int64_t len) {
 #if COMPILER_SUPPORTS_BF16_TARGET
   if (cpuinfo_has_arm_bf16()) {
     return dot_with_fp32_arith_bfdot(vec1, vec2, len);
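
This function is the runtime dispatch point: the compile-time gate decides whether the BFDOT kernel exists at all, and cpuinfo decides at runtime whether this CPU can run it. A self-contained sketch of that pattern (assumes the cpuinfo library; the two kernels are hypothetical stand-ins, not the file's functions):

```cpp
#include <cpuinfo.h>

// Stand-ins for the bfdot / no-bfdot kernels (identical here; in the
// real file the fast path uses BFDOT instructions).
static float kernel_fast(const float* a, const float* b, long n) {
  float s = 0;
  for (long i = 0; i < n; ++i) s += a[i] * b[i];
  return s;
}
static float kernel_portable(const float* a, const float* b, long n) {
  float s = 0;
  for (long i = 0; i < n; ++i) s += a[i] * b[i];
  return s;
}

// Compile-time gate plus runtime CPU check, mirroring the dispatch above.
float dispatch_dot(const float* a, const float* b, long n) {
  if (cpuinfo_initialize() && cpuinfo_has_arm_bf16()) {
    return kernel_fast(a, b, n); // only taken on BF16-capable CPUs
  }
  return kernel_portable(a, b, n);
}
```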
