Skip to content

Commit 0a8547a

Browse files
authored
Enable MKL on x86 to get around long-context discrepancies with torch.nn.functional.scaled_dot_product_attention
Differential Revision: D61931885 Pull Request resolved: #4948
1 parent 9c1a52c commit 0a8547a

File tree

4 files changed

+112
-80
lines changed

4 files changed

+112
-80
lines changed

extension/llm/custom_ops/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ runtime.python_test(
1414
"test_sdpa_with_kv_cache.py",
1515
],
1616
preload_deps = [
17-
":custom_ops_aot_lib",
17+
":custom_ops_aot_lib_mkl_noomp",
1818
":custom_ops_aot_py",
1919
],
2020
deps = [

extension/llm/custom_ops/targets.bzl

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,48 @@ def define_common_targets():
66
The directory containing this targets.bzl file should also contain both
77
TARGETS and BUCK files that call this function.
88
"""
9-
runtime.cxx_library(
10-
name = "custom_ops",
11-
srcs = ["op_sdpa.cpp", "op_fallback.cpp"],
12-
exported_headers = ["op_sdpa.h", "op_fallback.h"],
13-
exported_deps = [
14-
"//executorch/runtime/kernel:kernel_includes",
15-
"//executorch/kernels/portable/cpu:scalar_utils",
16-
"//executorch/kernels/optimized:libblas",
17-
"//executorch/kernels/optimized:libvec",
18-
"//executorch/extension/kernel_util:kernel_util",
19-
"//executorch/extension/parallel:thread_parallel",
20-
"//executorch/extension/threadpool:threadpool",
21-
],
22-
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
23-
visibility = [
24-
"//executorch/...",
25-
"//executorch/extension/llm/custom_ops/...",
26-
"@EXECUTORCH_CLIENTS",
27-
],
28-
# @lint-ignore BUCKLINT link_whole
29-
link_whole = True,
30-
force_static = True,
31-
)
9+
for mkl_dep in ["", "_mkl_noomp"]:
10+
runtime.cxx_library(
11+
name = "custom_ops" + mkl_dep,
12+
srcs = ["op_sdpa.cpp", "op_fallback.cpp"],
13+
exported_headers = ["op_sdpa.h", "op_fallback.h"],
14+
exported_deps = [
15+
"//executorch/runtime/kernel:kernel_includes",
16+
"//executorch/kernels/portable/cpu:scalar_utils",
17+
"//executorch/kernels/optimized:libblas{}".format(mkl_dep),
18+
"//executorch/kernels/optimized:libvec",
19+
"//executorch/extension/kernel_util:kernel_util",
20+
"//executorch/extension/parallel:thread_parallel",
21+
"//executorch/extension/threadpool:threadpool",
22+
],
23+
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
24+
visibility = [
25+
"//executorch/...",
26+
"//executorch/extension/llm/custom_ops/...",
27+
"@EXECUTORCH_CLIENTS",
28+
],
29+
# @lint-ignore BUCKLINT link_whole
30+
link_whole = True,
31+
force_static = True,
32+
)
3233

33-
runtime.cxx_library(
34-
name = "custom_ops_aot_lib",
35-
srcs = [
36-
"op_sdpa_aot.cpp",
37-
],
38-
visibility = [
39-
"//executorch/...",
40-
"@EXECUTORCH_CLIENTS",
41-
],
42-
external_deps = [
43-
"libtorch",
44-
],
45-
deps = [
46-
":custom_ops",
47-
"//executorch/extension/aten_util:aten_bridge",
48-
],
49-
)
34+
runtime.cxx_library(
35+
name = "custom_ops_aot_lib" + mkl_dep,
36+
srcs = [
37+
"op_sdpa_aot.cpp",
38+
],
39+
visibility = [
40+
"//executorch/...",
41+
"@EXECUTORCH_CLIENTS",
42+
],
43+
external_deps = [
44+
"libtorch",
45+
],
46+
deps = [
47+
":custom_ops" + mkl_dep,
48+
"//executorch/extension/aten_util:aten_bridge",
49+
],
50+
)
5051

5152
runtime.python_library(
5253
name = "custom_ops_aot_py",

kernels/optimized/lib_defs.bzl

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFORM_REGEX")
2+
load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
23
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
34

45
# Because vec exists as a collection of header files, compile and preprocessor
@@ -99,44 +100,64 @@ def define_libs():
99100
],
100101
)
101102

102-
runtime.cxx_library(
103-
name = "libblas",
104-
srcs = native.glob([
105-
"blas/**/*.cpp",
106-
]),
107-
exported_headers = native.glob([
108-
"blas/**/*.h",
109-
]),
110-
header_namespace = "executorch/kernels/optimized",
111-
visibility = [
112-
"//executorch/...",
113-
"@EXECUTORCH_CLIENTS",
114-
],
115-
fbandroid_platform_preprocessor_flags = [
116-
(
117-
"^android-arm64.*$",
118-
[
119-
"-DET_BUILD_WITH_BLAS",
120-
],
121-
),
122-
],
123-
fbandroid_platform_deps = [
124-
(
125-
"^android-arm64.*$",
126-
[
127-
"fbsource//third-party/openblas:openblas",
128-
],
129-
),
130-
],
131-
fbobjc_exported_preprocessor_flags = [
132-
"-DET_BUILD_WITH_BLAS",
133-
"-DET_BUILD_FOR_APPLE",
134-
],
135-
fbobjc_frameworks = [
136-
"Accelerate",
137-
],
138-
exported_deps = [
139-
"//executorch/kernels/optimized:libutils",
140-
"//executorch/runtime/core/exec_aten:lib",
103+
# OSS doesn't have ovr_config//os:linux-x86_64
104+
fb_native.config_setting(
105+
name = "linux-x86_64",
106+
constraint_values = [
107+
"ovr_config//os/constraints:linux",
108+
"ovr_config//cpu/constraints:x86_64",
141109
],
142110
)
111+
112+
for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp"), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl")]:
113+
runtime.cxx_library(
114+
name = libblas_name,
115+
srcs = native.glob([
116+
"blas/**/*.cpp",
117+
]),
118+
exported_headers = native.glob([
119+
"blas/**/*.h",
120+
]),
121+
header_namespace = "executorch/kernels/optimized",
122+
visibility = [
123+
"//executorch/...",
124+
"@EXECUTORCH_CLIENTS",
125+
],
126+
preprocessor_flags = select({
127+
":linux-x86_64": [
128+
"-DET_BUILD_WITH_BLAS",
129+
] if not runtime.is_oss else [],
130+
"DEFAULT": [],
131+
}),
132+
fbandroid_platform_preprocessor_flags = [
133+
(
134+
"^android-arm64.*$",
135+
[
136+
"-DET_BUILD_WITH_BLAS",
137+
],
138+
),
139+
],
140+
fbandroid_platform_deps = [
141+
(
142+
"^android-arm64.*$",
143+
[
144+
"fbsource//third-party/openblas:openblas",
145+
],
146+
),
147+
],
148+
fbobjc_exported_preprocessor_flags = [
149+
"-DET_BUILD_WITH_BLAS",
150+
"-DET_BUILD_FOR_APPLE",
151+
],
152+
fbobjc_frameworks = [
153+
"Accelerate",
154+
],
155+
deps = select({
156+
":linux-x86_64": [mkl_dep] if not runtime.is_oss else [],
157+
"DEFAULT": [],
158+
}),
159+
exported_deps = [
160+
"//executorch/kernels/optimized:libutils",
161+
"//executorch/runtime/core/exec_aten:lib",
162+
],
163+
)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under both the MIT license found in the
4+
# LICENSE-MIT file in the root directory of this source tree and the Apache
5+
# License, Version 2.0 found in the LICENSE-APACHE file in the root directory
6+
# of this source tree.
7+
8+
fb_native = struct(
9+
config_setting = native.config_setting,
10+
)

0 commit comments

Comments
 (0)