
Commit bbe0ebd

pytorchbot and swolchok authored
[ExecuTorch] Reapply D62466496: Build optimized kernels with bf16 support and gate usage at runtime (#5420)
Reapply D62466496: Build optimized kernels with bf16 support and gate usage at runtime (#5376)

Summary:
Pull Request resolved: #5376

Now with fewer broken tests.

ghstack-source-id: 242772181

Reviewed By: kimishpatel

Differential Revision: D62680594

fbshipit-source-id: 517791f303165423977593631e93368b95864e95
(cherry picked from commit 2b3cc27)

Co-authored-by: Scott Wolchok <[email protected]>
1 parent a5d6789 commit bbe0ebd
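The commit follows a compile-for-the-best-case, decide-at-runtime pattern: the kernels are built with bf16 instructions enabled, and cpuinfo is consulted on the running device before the bfdot path is taken (see bf16_dot_with_fp32_arith in BlasKernel.cpp below). A minimal sketch of that layering, using hypothetical bf16_kernel/fallback_kernel stand-ins rather than the real ExecuTorch functions:

// Sketch only: bf16_kernel and fallback_kernel are hypothetical stand-ins,
// not functions from this commit.
#include <cstdint>

#include <cpuinfo.h>

float bf16_kernel(const uint16_t* a, const uint16_t* b, int64_t len);      // would use vbfdotq_f32
float fallback_kernel(const uint16_t* a, const uint16_t* b, int64_t len);  // plain fp32 arithmetic

float dispatch_dot(const uint16_t* a, const uint16_t* b, int64_t len) {
#if defined(__aarch64__) && defined(__ARM_FEATURE_BF16)
  // Compiled in only when the toolchain enables bf16 (see -march=armv8+bf16 in lib_defs.bzl).
  if (cpuinfo_has_arm_bf16()) {  // runtime check: does this CPU actually implement BF16?
    return bf16_kernel(a, b, len);
  }
#endif
  return fallback_kernel(a, b, len);
}

Note that cpuinfo feature queries generally require cpuinfo_initialize() to have been called first; that initialization is not shown in this diff.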

File tree

4 files changed: +61 −33 lines changed

    kernels/optimized/blas/BlasKernel.cpp
    kernels/optimized/lib_defs.bzl
    kernels/test/op_linear_test.cpp
    shim/xplat/executorch/build/env_interface.bzl

kernels/optimized/blas/BlasKernel.cpp

Lines changed: 40 additions & 27 deletions
@@ -10,6 +10,7 @@
 
 #ifdef __aarch64__
 #include <arm_neon.h>
+#include <cpuinfo.h>
 #endif
 
 using torch::executor::BFloat16;
@@ -23,7 +24,7 @@ static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
   return vfmaq_f32(a, b, c);
 #else
   return vaddq_f32(a, vmulq_f32(b, c));
-#endif
+#endif // __ARM_FEATURE_FMA
 }
 
 // The below reduce overload and fp16_dot_with_fp32_arith are adapted
@@ -78,35 +79,39 @@ static ET_INLINE float32x4_t
 f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
   return vbfdotq_f32(a, b, c);
 }
-#endif
+#endif // __ARM_FEATURE_BF16
 
+template <bool useBfloat16Dot>
 static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
     const BFloat16* vec1,
     const BFloat16* vec2,
     float32x4_t sum[kF32RegistersPerIteration],
     int registerPairIndex) {
 #ifdef __ARM_FEATURE_BF16
-  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  sum[registerPairIndex] =
-      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
-#else
-  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-
-  sum[2 * registerPairIndex] = f32_fma_bf16(
-      sum[2 * registerPairIndex],
-      vget_low_u16(temp_vec1),
-      vget_low_u16(temp_vec2));
-  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
-      sum[2 * registerPairIndex + 1],
-      vget_high_u16(temp_vec1),
-      vget_high_u16(temp_vec2));
-#endif
+  if (useBfloat16Dot) {
+    const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    sum[registerPairIndex] =
+        f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+  } else
+#endif // __ARM_FEATURE_BF16
+  {
+    const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+    const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+    sum[2 * registerPairIndex] = f32_fma_bf16(
+        sum[2 * registerPairIndex],
+        vget_low_u16(temp_vec1),
+        vget_low_u16(temp_vec2));
+    sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+        sum[2 * registerPairIndex + 1],
+        vget_high_u16(temp_vec1),
+        vget_high_u16(temp_vec2));
+  }
 }
 
 static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
@@ -121,7 +126,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
   *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
 }
 
-template <typename T>
+template <typename T, bool useBfloat16Dot>
 float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
@@ -130,7 +135,8 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
     const auto* vec2_ = vec2 + j;
     utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
         [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
-          dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
+          dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
+              vec1_, vec2_, sum, k);
         });
   }
   auto reducedSum = reduce(sum);
@@ -157,9 +163,16 @@ float bf16_dot_with_fp32_arith(
     const BFloat16* vec1,
     const BFloat16* vec2,
     int64_t len) {
-  return dot_with_fp32_arith(vec1, vec2, len);
+#ifdef __ARM_FEATURE_BF16
+  if (cpuinfo_has_arm_bf16()) {
+    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
+  } else
+#endif // __ARM_FEATURE_BF16
+  {
+    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
+  }
 }
-#endif
+#endif // __aarch64__
 } // namespace internal
 } // namespace cpublas
 } // namespace executorch
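When the bfdot path is not taken (no __ARM_FEATURE_BF16 at build time, or useBfloat16Dot is false at runtime), the inputs are loaded as uint16 lanes and handed to f32_fma_bf16, which accumulates in fp32. A scalar illustration of why that works, assuming the helper widens each bf16 lane into the upper half of a float32 (the helper itself lies outside this hunk):

// Scalar illustration only, not the NEON code in BlasKernel.cpp.
// bfloat16 is the upper 16 bits of an IEEE float32, so shifting the raw
// bits left by 16 reconstructs the float value exactly.
#include <cstdint>
#include <cstring>

static float bf16_to_f32(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

// Fallback-style dot product: widen each bf16 operand and accumulate in fp32.
static float bf16_dot_scalar(const uint16_t* a, const uint16_t* b, int64_t len) {
  float sum = 0.0f;
  for (int64_t i = 0; i < len; ++i) {
    sum += bf16_to_f32(a[i]) * bf16_to_f32(b[i]);
  }
  return sum;
}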

kernels/optimized/lib_defs.bzl

Lines changed: 15 additions & 1 deletion
@@ -1,5 +1,6 @@
 load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFORM_REGEX")
 load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
+load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 # Because vec exists as a collection of header files, compile and preprocessor
@@ -109,6 +110,8 @@ def define_libs():
         ],
     )
 
+    LIBBLAS_DEPS = [third_party_dep("cpuinfo")]
+
     for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp"), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl")]:
         runtime.cxx_library(
             name = libblas_name,
@@ -129,6 +132,14 @@ def define_libs():
                 ] if not runtime.is_oss else [],
                 "DEFAULT": [],
             }),
+            fbandroid_platform_compiler_flags = [
+                (
+                    "^android-arm64.*$",
+                    [
+                        "-march=armv8+bf16",
+                    ],
+                ),
+            ],
             fbandroid_platform_preprocessor_flags = [
                 (
                     "^android-arm64.*$",
@@ -145,6 +156,9 @@ def define_libs():
                     ],
                 ),
             ],
+            fbobjc_compiler_flags = [
+                "-march=armv8+bf16",
+            ],
             fbobjc_exported_preprocessor_flags = [
                 "-DET_BUILD_WITH_BLAS",
                 "-DET_BUILD_FOR_APPLE",
@@ -155,7 +169,7 @@ def define_libs():
             deps = select({
                 ":linux-x86_64": [mkl_dep] if not runtime.is_oss else [],
                 "DEFAULT": [],
-            }),
+            }) + LIBBLAS_DEPS,
             exported_deps = [
                 "//executorch/extension/parallel:thread_parallel",
                 "//executorch/kernels/optimized:libutils",

kernels/test/op_linear_test.cpp

Lines changed: 4 additions & 4 deletions
@@ -43,16 +43,16 @@ class OpLinearOutTest : public OperatorTest {
      }
    }

-    // matmul gives 4 * 2 * 3 = 24
-    Tensor x = tf.full({3, 4}, 2);
-    Tensor y = tf.full({5, 4}, 3);
+    // matmul gives 32 * 2 * 3 = 192
+    Tensor x = tf.full({3, 32}, 2);
+    Tensor y = tf.full({5, 32}, 3);
 
     // Output shape should be (3, 5)
     Tensor out = tf.zeros({3, 5});
 
     op_linear_out(x, y, out);
 
-    Tensor expected = tf.full({3, 5}, 24);
+    Tensor expected = tf.full({3, 5}, 192);
 
     EXPECT_TENSOR_EQ(out, expected);
   }
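The test now uses an inner dimension of 32 instead of 4, presumably so the dot products are long enough to exercise the vectorized main loop rather than only the tail path. The expected value follows directly: each element of the (3, 5) output is a length-32 dot product of all-2s with all-3s, i.e. 32 × 2 × 3 = 192. A trivial reference computation:

// Trivial reference for the updated expectation.
#include <cstdio>

int main() {
  float acc = 0.0f;
  for (int i = 0; i < 32; ++i) {
    acc += 2.0f * 3.0f;  // one element of a row of x times the matching element of a row of y
  }
  std::printf("%g\n", acc);  // prints 192
  return 0;
}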

shim/xplat/executorch/build/env_interface.bzl

Lines changed: 2 additions & 1 deletion
@@ -118,7 +118,8 @@ def _remove_platform_specific_args(kwargs):
     """
     keys = []
     for key in kwargs:
-        if key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or key.startswith("fbobjc"):
+        if (key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or
+                key.startswith("fbobjc") or key.endswith("_platform_compiler_flags")):
             keys.append(key)
     for key in keys:
         kwargs.pop(key)
