Skip to content

Commit 8471c22

Browse files
authored
Enable MKL on x86 to get around long-context discrepancies with torch.nn.functional.scaled_dot_product_attention
Differential Revision: D61290864 Pull Request resolved: #4758
1 parent 7b4be54 commit 8471c22

File tree

2 files changed

+92
-80
lines changed

2 files changed

+92
-80
lines changed

extension/llm/custom_ops/targets.bzl

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,48 @@ def define_common_targets():
66
The directory containing this targets.bzl file should also contain both
77
TARGETS and BUCK files that call this function.
88
"""
9-
runtime.cxx_library(
10-
name = "custom_ops",
11-
srcs = ["op_sdpa.cpp"],
12-
exported_headers = ["op_sdpa.h"],
13-
exported_deps = [
14-
"//executorch/runtime/kernel:kernel_includes",
15-
"//executorch/kernels/portable/cpu:scalar_utils",
16-
"//executorch/kernels/optimized:libblas",
17-
"//executorch/kernels/optimized:libvec",
18-
"//executorch/extension/kernel_util:kernel_util",
19-
"//executorch/extension/parallel:thread_parallel",
20-
"//executorch/backends/xnnpack/threadpool:threadpool",
21-
],
22-
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
23-
visibility = [
24-
"//executorch/...",
25-
"//executorch/extension/llm/custom_ops/...",
26-
"@EXECUTORCH_CLIENTS",
27-
],
28-
# @lint-ignore BUCKLINT link_whole
29-
link_whole = True,
30-
force_static = True,
31-
)
9+
for mkl_dep in ["", "_mkl_lp64_omp"]:
10+
runtime.cxx_library(
11+
name = "custom_ops" + mkl_dep,
12+
srcs = ["op_sdpa.cpp"],
13+
exported_headers = ["op_sdpa.h"],
14+
exported_deps = [
15+
"//executorch/runtime/kernel:kernel_includes",
16+
"//executorch/kernels/portable/cpu:scalar_utils",
17+
"//executorch/kernels/optimized:libblas{}".format(mkl_dep),
18+
"//executorch/kernels/optimized:libvec",
19+
"//executorch/extension/kernel_util:kernel_util",
20+
"//executorch/extension/parallel:thread_parallel",
21+
"//executorch/backends/xnnpack/threadpool:threadpool",
22+
],
23+
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
24+
visibility = [
25+
"//executorch/...",
26+
"//executorch/extension/llm/custom_ops/...",
27+
"@EXECUTORCH_CLIENTS",
28+
],
29+
# @lint-ignore BUCKLINT link_whole
30+
link_whole = True,
31+
force_static = True,
32+
)
3233

33-
runtime.cxx_library(
34-
name = "custom_ops_aot_lib",
35-
srcs = [
36-
"op_sdpa_aot.cpp",
37-
],
38-
visibility = [
39-
"//executorch/...",
40-
"@EXECUTORCH_CLIENTS",
41-
],
42-
external_deps = [
43-
"libtorch",
44-
],
45-
deps = [
46-
":custom_ops",
47-
"//executorch/extension/aten_util:aten_bridge",
48-
],
49-
)
34+
runtime.cxx_library(
35+
name = "custom_ops_aot_lib" + mkl_dep,
36+
srcs = [
37+
"op_sdpa_aot.cpp",
38+
],
39+
visibility = [
40+
"//executorch/...",
41+
"@EXECUTORCH_CLIENTS",
42+
],
43+
external_deps = [
44+
"libtorch",
45+
],
46+
deps = [
47+
":custom_ops" + mkl_dep,
48+
"//executorch/extension/aten_util:aten_bridge",
49+
],
50+
)
5051

5152
runtime.python_library(
5253
name = "custom_ops_aot_py",

kernels/optimized/lib_defs.bzl

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -99,44 +99,55 @@ def define_libs():
9999
],
100100
)
101101

102-
runtime.cxx_library(
103-
name = "libblas",
104-
srcs = native.glob([
105-
"blas/**/*.cpp",
106-
]),
107-
exported_headers = native.glob([
108-
"blas/**/*.h",
109-
]),
110-
header_namespace = "executorch/kernels/optimized",
111-
visibility = [
112-
"//executorch/...",
113-
"@EXECUTORCH_CLIENTS",
114-
],
115-
fbandroid_platform_preprocessor_flags = [
116-
(
117-
"^android-arm64.*$",
118-
[
102+
for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl"), ("libblas_mkl_lp64_omp", "fbsource//third-party/mkl:mkl_lp64_omp")]:
103+
runtime.cxx_library(
104+
name = libblas_name,
105+
srcs = native.glob([
106+
"blas/**/*.cpp",
107+
]),
108+
exported_headers = native.glob([
109+
"blas/**/*.h",
110+
]),
111+
header_namespace = "executorch/kernels/optimized",
112+
visibility = [
113+
"//executorch/...",
114+
"@EXECUTORCH_CLIENTS",
115+
],
116+
preprocessor_flags = select({
117+
"DEFAULT": [],
118+
"ovr_config//os:linux-x86_64": [
119119
"-DET_BUILD_WITH_BLAS",
120-
],
121-
),
122-
],
123-
fbandroid_platform_deps = [
124-
(
125-
"^android-arm64.*$",
126-
[
127-
"fbsource//third-party/openblas:openblas",
128-
],
129-
),
130-
],
131-
fbobjc_exported_preprocessor_flags = [
132-
"-DET_BUILD_WITH_BLAS",
133-
"-DET_BUILD_FOR_APPLE",
134-
],
135-
fbobjc_frameworks = [
136-
"Accelerate",
137-
],
138-
exported_deps = [
139-
"//executorch/kernels/optimized:libutils",
140-
"//executorch/runtime/core/exec_aten:lib",
141-
],
142-
)
120+
] if not runtime.is_oss else [],
121+
}),
122+
fbandroid_platform_preprocessor_flags = [
123+
(
124+
"^android-arm64.*$",
125+
[
126+
"-DET_BUILD_WITH_BLAS",
127+
],
128+
),
129+
],
130+
fbandroid_platform_deps = [
131+
(
132+
"^android-arm64.*$",
133+
[
134+
"fbsource//third-party/openblas:openblas",
135+
],
136+
),
137+
],
138+
fbobjc_exported_preprocessor_flags = [
139+
"-DET_BUILD_WITH_BLAS",
140+
"-DET_BUILD_FOR_APPLE",
141+
],
142+
fbobjc_frameworks = [
143+
"Accelerate",
144+
],
145+
deps = select({
146+
"DEFAULT": [],
147+
"ovr_config//os:linux-x86_64": [mkl_dep] if not runtime.is_oss else [],
148+
}),
149+
exported_deps = [
150+
"//executorch/kernels/optimized:libutils",
151+
"//executorch/runtime/core/exec_aten:lib",
152+
],
153+
)

0 commit comments

Comments
 (0)