Skip to content

Commit 0eee42a

Browse files
swolchokfacebook-github-bot
authored andcommitted
Don't require -march compiler flags to use bfdot (#5444)
Summary: Pull Request resolved: #5444 TIL about the `target` clang/GCC function attribute, which allows building a particular function under an `-march` flag instead of a whole file. ghstack-source-id: 243858419 Reviewed By: malfet Differential Revision: D62905047 fbshipit-source-id: a89c8169fea315aa653bbca819a672357c3dff77
1 parent c50f9fe commit 0eee42a

File tree

2 files changed

+76
-54
lines changed

2 files changed

+76
-54
lines changed

kernels/optimized/blas/BlasKernel.cpp

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -74,43 +74,60 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
7474
return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
7575
}
7676

77-
#ifdef __ARM_FEATURE_BF16
78-
static ET_INLINE float32x4_t
77+
#define ET_TARGET_ARM_BF16_ATTRIBUTE \
78+
__attribute__((target("arch=armv8.2-a+bf16")))
79+
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
7980
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
8081
return vbfdotq_f32(a, b, c);
8182
}
82-
#endif // __ARM_FEATURE_BF16
8383

84-
template <bool useBfloat16Dot>
85-
static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
84+
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
85+
dot_with_fp32_arith_main_inner_loop_bfdot(
86+
const BFloat16* vec1,
87+
const BFloat16* vec2,
88+
float32x4_t sum[kF32RegistersPerIteration],
89+
int registerPairIndex) {
90+
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
91+
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
92+
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
93+
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
94+
sum[registerPairIndex] =
95+
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
96+
}
97+
98+
static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
99+
const BFloat16* vec1,
100+
const BFloat16* vec2,
101+
float32x4_t sum[kF32RegistersPerIteration],
102+
int registerPairIndex) {
103+
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
104+
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
105+
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
106+
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
107+
108+
sum[2 * registerPairIndex] = f32_fma_bf16(
109+
sum[2 * registerPairIndex],
110+
vget_low_u16(temp_vec1),
111+
vget_low_u16(temp_vec2));
112+
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
113+
sum[2 * registerPairIndex + 1],
114+
vget_high_u16(temp_vec1),
115+
vget_high_u16(temp_vec2));
116+
}
117+
118+
template <bool useBfdot>
119+
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
120+
dot_with_fp32_arith_main_inner_loop(
86121
const BFloat16* vec1,
87122
const BFloat16* vec2,
88123
float32x4_t sum[kF32RegistersPerIteration],
89124
int registerPairIndex) {
90-
#ifdef __ARM_FEATURE_BF16
91-
if (useBfloat16Dot) {
92-
const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
93-
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
94-
const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
95-
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
96-
sum[registerPairIndex] =
97-
f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
98-
} else
99-
#endif // __ARM_FEATURE_BF16
100-
{
101-
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
102-
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
103-
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
104-
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
105-
106-
sum[2 * registerPairIndex] = f32_fma_bf16(
107-
sum[2 * registerPairIndex],
108-
vget_low_u16(temp_vec1),
109-
vget_low_u16(temp_vec2));
110-
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
111-
sum[2 * registerPairIndex + 1],
112-
vget_high_u16(temp_vec1),
113-
vget_high_u16(temp_vec2));
125+
if constexpr (useBfdot) {
126+
dot_with_fp32_arith_main_inner_loop_bfdot(
127+
vec1, vec2, sum, registerPairIndex);
128+
} else {
129+
dot_with_fp32_arith_main_inner_loop_no_bfdot(
130+
vec1, vec2, sum, registerPairIndex);
114131
}
115132
}
116133

@@ -126,18 +143,40 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
126143
*tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
127144
}
128145

129-
template <typename T, bool useBfloat16Dot>
130-
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
146+
namespace {
147+
template <int n>
148+
struct ForcedUnrollTargetBFloat16 {
149+
template <typename Func>
150+
ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
151+
ForcedUnrollTargetBFloat16<n - 1>{}(f);
152+
f(n - 1);
153+
}
154+
};
155+
156+
template <>
157+
struct ForcedUnrollTargetBFloat16<1> {
158+
template <typename Func>
159+
ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
160+
f(0);
161+
}
162+
};
163+
164+
} // namespace
165+
166+
template <typename T, bool useBFloat16Dot>
167+
ET_TARGET_ARM_BF16_ATTRIBUTE float
168+
dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
131169
float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
132170
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
133171
for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
134172
const auto* vec1_ = vec1 + j;
135173
const auto* vec2_ = vec2 + j;
136-
utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
137-
[vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
138-
dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
139-
vec1_, vec2_, sum, k);
140-
});
174+
ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
175+
[vec1_, vec2_, &sum](auto k)
176+
ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
177+
dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
178+
vec1_, vec2_, sum, k);
179+
});
141180
}
142181
auto reducedSum = reduce(sum);
143182

@@ -163,12 +202,9 @@ float bf16_dot_with_fp32_arith(
163202
const BFloat16* vec1,
164203
const BFloat16* vec2,
165204
int64_t len) {
166-
#ifdef __ARM_FEATURE_BF16
167205
if (cpuinfo_has_arm_bf16()) {
168206
return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
169-
} else
170-
#endif // __ARM_FEATURE_BF16
171-
{
207+
} else {
172208
return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
173209
}
174210
}

kernels/optimized/lib_defs.bzl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,6 @@ def define_libs():
132132
] if not runtime.is_oss else [],
133133
"DEFAULT": [],
134134
}),
135-
fbandroid_platform_compiler_flags = [
136-
(
137-
"^android-arm64.*$",
138-
[
139-
"-march=armv8+bf16",
140-
],
141-
),
142-
],
143135
fbandroid_platform_preprocessor_flags = [
144136
(
145137
"^android-arm64.*$",
@@ -156,12 +148,6 @@ def define_libs():
156148
],
157149
),
158150
],
159-
fbobjc_platform_compiler_flags = [
160-
(
161-
".*arm64.*",
162-
["-march=armv8+bf16"],
163-
),
164-
],
165151
fbobjc_exported_preprocessor_flags = [
166152
"-DET_BUILD_WITH_BLAS",
167153
"-DET_BUILD_FOR_APPLE",

0 commit comments

Comments
 (0)