Commit a26d2bf

Reuse GELU implementation from PyTorch core

Pull Request resolved: #7041

kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dependency on PyTorch. Note that, because ATen vec picks up Sleef internally and ignores it externally, this PR also gets to enable the optimized GELU kernel in OSS.

Testing: CI, to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break.

ghstack-source-id: 258553927
@exported-using-ghexport

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)

1 parent a25444c commit a26d2bf
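As a minimal sketch of what the header-only dependency buys (illustrative only — the standalone function and includes are my own, though `vectorized_gelu`/`scalar_gelu` are the helpers the new kernel calls; assumes the translation unit is built with the same `CPU_CAPABILITY_*` flags this change exports):

```cpp
#include <cstddef>

#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/Gelu.h>

// Exact (erf-based) GELU over a contiguous buffer, using only ATen headers:
// a SIMD main loop over full Vectorized<float> lanes, then a scalar tail.
void gelu_exact(const float* in, float* out, std::size_t n) {
  using Vec = at::vec::Vectorized<float>;
  std::size_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    at::native::vectorized_gelu(Vec::loadu(in + i)).store(out + i);
  }
  for (; i < n; ++i) {
    out[i] = at::native::scalar_gelu(in[i]);
  }
}
```

No libtorch linkage is required; where `AT_BUILD_ARM_VEC256_WITH_SLEEF` is defined the vec layer routes through Sleef internally, and otherwise it falls back to its portable implementation.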

File tree

4 files changed, +70 −42 lines changed:

- kernels/optimized/CMakeLists.txt
- kernels/optimized/cpu/op_gelu.cpp
- kernels/optimized/cpu/targets.bzl
- kernels/optimized/optimized-oss.yaml

kernels/optimized/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

```diff
@@ -60,6 +60,8 @@ message("Generated files ${gen_command_sources}")

 list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
 add_library(optimized_kernels ${_optimized_kernels__srcs})
+find_package(Torch CONFIG REQUIRED)
+target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
 target_link_libraries(
   optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
 )
```

kernels/optimized/cpu/op_gelu.cpp

Lines changed: 15 additions & 36 deletions

```diff
@@ -13,6 +13,7 @@

 #include <cmath>

+#include <ATen/native/cpu/Gelu.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>

@@ -46,48 +47,26 @@ void gelu(
   CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
   size_t lim = input.numel();

-  // TODO: Add fast path for tanh using sleef's tanh
   if (approximate == "tanh") {
-    // 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
-    for (size_t i = 0; i < lim; ++i) {
-      const CTYPE x = in_data[i];
-      const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-      const CTYPE kKappa = 0.044715;
-      auto x_cube = x * x * x;
-      auto inner = kBeta * (x + kKappa * x_cube);
-      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
+    using Vec = at::vec::Vectorized<CTYPE>;
+    int i = 0;
+    for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
+      Vec x = Vec::loadu(in_data + i);
+      at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i);
     }
-  } else if (approximate == "none") { // dont appx
-    // GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
-    // Function for Gaussian Distribution.
-
-#ifndef __aarch64__
-    for (size_t i = 0; i < lim; ++i) {
-      const CTYPE x = in_data[i];
-      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
+    for (; i < lim; ++i) {
+      out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]);
     }
-#else
-    size_t i = 0;
-    if (std::is_same<CTYPE, float>::value) {
-      for (; i + 4 < lim; i += 4) {
-        const float32x4_t in =
-            vld1q_f32(static_cast<const float*>(&in_data[i]));
-        const float32x4_t m_sqrt1_2x4 = {
-            M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
-        const float32x4_t ones = vmovq_n_f32(1.0);
-        const float32x4_t halves = vmovq_n_f32(0.5);
-        float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
-        vst1q_f32(
-            static_cast<float*>(&out_data[i]),
-            vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
-      }
+  } else if (approximate == "none") {
+    using Vec = at::vec::Vectorized<CTYPE>;
+    int i = 0;
+    for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
+      Vec x = Vec::loadu(in_data + i);
+      at::native::vectorized_gelu(x).store(out_data + i);
     }
     for (; i < lim; ++i) {
-      const CTYPE x = in_data[i];
-      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
+      out_data[i] = at::native::scalar_gelu(in_data[i]);
     }
-#endif // __aarch64__
-
   } else {
     ET_KERNEL_CHECK_MSG(
         context,
```
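For reference, the quantities computed are unchanged by this diff; per the comments and constants in the removed scalar code (`kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2/pi)`, `kKappa = 0.044715`), the two variants are:

```math
\mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
\qquad
\mathrm{GELU}_{\mathrm{tanh}}(x) = \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
```

In both branches the new code rounds `lim` down to a multiple of `Vec::size()` (`lim - (lim % Vec::size())`) so the SIMD loop only touches full lanes, then finishes the remainder with the matching scalar helper.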

kernels/optimized/cpu/targets.bzl

Lines changed: 48 additions & 6 deletions

```diff
@@ -28,12 +28,9 @@ _OPTIMIZED_ATEN_OPS = (
     op_target(name = "op_sigmoid"),
     op_target(
         name = "op_gelu",
-        deps = select({
-            "DEFAULT": [],
-            "ovr_config//cpu:arm64": [
-                "fbsource//third-party/sleef:sleef_arm",
-            ],
-        }),
+        deps = [
+            ":aten_headers_for_executorch",
+        ],
     ),
     op_target(
         name = "op_le",
@@ -94,6 +91,13 @@ _OPTIMIZED_ATEN_OPS = (
     ),
 )

+
+def get_sleef_preprocessor_flags():
+    if runtime.is_oss:
+        return []
+    return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"]
+
+
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.

@@ -110,6 +114,44 @@ def define_common_targets():
     aten_op_targets = [":{}".format(op["name"]) for op in enabled_ops]
     all_op_targets = aten_op_targets

+    runtime.cxx_library(
+        name = "aten_headers_for_executorch",
+        srcs = [],
+        visibility = ["//executorch/kernels/optimized/..."],
+        exported_deps = select({
+            "DEFAULT": [],
+            "ovr_config//cpu:arm64": [
+                "fbsource//third-party/sleef:sleef_arm",
+            ] if not runtime.is_oss else [],
+            # fbsource//third-party/sleef:sleef currently fails to
+            # link with missing symbols, hence the fbcode-specific dep below.
+        }),
+        fbcode_exported_deps = [
+            "//caffe2:aten-headers-cpu",
+            "//caffe2:generated-config-header",
+            "//caffe2/c10/core:base",
+        ] + select({
+            "DEFAULT": [],
+            "ovr_config//cpu:x86_64": [
+                "third-party//sleef:sleef",
+            ]
+        }),
+        xplat_exported_deps = [
+            "//xplat/caffe2:aten_header",
+            "//xplat/caffe2:generated_aten_config_header",
+            "//xplat/caffe2/c10:c10",
+        ],
+        exported_preprocessor_flags = select({
+            "ovr_config//cpu:x86_64": [
+                "-DCPU_CAPABILITY=AVX2",
+                "-DCPU_CAPABILITY_AVX2",
+                "-DHAVE_AVX2_CPU_DEFINITION",
+            ] + get_sleef_preprocessor_flags(),
+            "ovr_config//cpu:arm64": get_sleef_preprocessor_flags(),
+            "DEFAULT": [],
+        }) + ["-DSTANDALONE_TORCH_HEADER"],
+    )
+
     runtime.cxx_library(
         name = "binary_ops",
         exported_headers = ["binary_ops.h"],
```
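A rough, self-contained probe (my own sketch, not part of the commit) of the `at::vec` surface these flags configure — the lane width and backend are fixed at compile time by the `CPU_CAPABILITY_*` defines this target exports, with a portable fallback when none is set:

```cpp
#include <cstdio>
#include <vector>

#include <ATen/cpu/vec/vec.h>

int main() {
  using Vec = at::vec::Vectorized<float>;
  // With -DCPU_CAPABILITY_AVX2 this reports 256-bit lanes (8 floats);
  // the exact count depends on the flags the TU was compiled with.
  std::printf("lanes: %d\n", static_cast<int>(Vec::size()));

  std::vector<float> in(Vec::size(), 1.5f), out(Vec::size());
  Vec x = Vec::loadu(in.data());      // unaligned load, as in op_gelu.cpp
  (x * Vec(2.0f)).store(out.data());  // broadcast, multiply, store back
  std::printf("out[0] = %g\n", out[0]);  // prints 3
  return 0;
}
```

Nothing here links against libtorch; the defines only change how the headers compile, which is presumably also the point of `-DSTANDALONE_TORCH_HEADER`.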

kernels/optimized/optimized-oss.yaml

Lines changed: 5 additions & 0 deletions

```diff
@@ -40,6 +40,11 @@
     - arg_meta: null
       kernel_name: torch::executor::opt_sigmoid_out

+- op: gelu.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_gelu_out
+
 - op: le.Scalar_out
   kernels:
     - arg_meta: null
```
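For completeness, a sketch of the declaration the `kernel_name` resolves to — the exact signature is an assumption based on the usual ExecuTorch out-variant convention (runtime context first, `out` last and returned), not quoted from this commit:

```cpp
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

// Assumed shape, mirroring the ATen schema:
//   gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out)
Tensor& opt_gelu_out(
    KernelRuntimeContext& context,
    const Tensor& in,
    exec_aten::string_view approximate,
    Tensor& out);

} // namespace executor
} // namespace torch
```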
