Skip to content

Commit fd88afa

Browse files
committed
Merge remote-tracking branch 'origin/main' into aar-for-bench-2
2 parents 746bf6f + e7e8647 commit fd88afa

File tree

7 files changed

+193
-96
lines changed

7 files changed

+193
-96
lines changed

extension/android/benchmark/app/build.gradle.kts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ android {
88

99
defaultConfig {
1010
applicationId = "org.pytorch.minibench"
11-
minSdk = 24
12-
targetSdk = 34
11+
minSdk = 28
12+
targetSdk = 33
1313
versionCode = 1
1414
versionName = "1.0"
1515

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
// Top-level build file where you can add configuration options common to all sub-projects/modules.
22
plugins {
3-
id("com.android.application") version "8.2.2" apply false
3+
id("com.android.application") version "8.1.0" apply false
44
}

extension/llm/custom_ops/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ runtime.python_test(
1414
"test_sdpa_with_kv_cache.py",
1515
],
1616
preload_deps = [
17-
":custom_ops_aot_lib",
17+
":custom_ops_aot_lib_mkl_noomp",
1818
":custom_ops_aot_py",
1919
],
2020
deps = [

extension/llm/custom_ops/targets.bzl

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,48 @@ def define_common_targets():
66
The directory containing this targets.bzl file should also contain both
77
TARGETS and BUCK files that call this function.
88
"""
9-
runtime.cxx_library(
10-
name = "custom_ops",
11-
srcs = ["op_sdpa.cpp", "op_fallback.cpp"],
12-
exported_headers = ["op_sdpa.h", "op_fallback.h"],
13-
exported_deps = [
14-
"//executorch/runtime/kernel:kernel_includes",
15-
"//executorch/kernels/portable/cpu:scalar_utils",
16-
"//executorch/kernels/optimized:libblas",
17-
"//executorch/kernels/optimized:libvec",
18-
"//executorch/extension/kernel_util:kernel_util",
19-
"//executorch/extension/parallel:thread_parallel",
20-
"//executorch/extension/threadpool:threadpool",
21-
],
22-
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
23-
visibility = [
24-
"//executorch/...",
25-
"//executorch/extension/llm/custom_ops/...",
26-
"@EXECUTORCH_CLIENTS",
27-
],
28-
# @lint-ignore BUCKLINT link_whole
29-
link_whole = True,
30-
force_static = True,
31-
)
9+
for mkl_dep in ["", "_mkl_noomp"]:
10+
runtime.cxx_library(
11+
name = "custom_ops" + mkl_dep,
12+
srcs = ["op_sdpa.cpp", "op_fallback.cpp"],
13+
exported_headers = ["op_sdpa.h", "op_fallback.h"],
14+
exported_deps = [
15+
"//executorch/runtime/kernel:kernel_includes",
16+
"//executorch/kernels/portable/cpu:scalar_utils",
17+
"//executorch/kernels/optimized:libblas{}".format(mkl_dep),
18+
"//executorch/kernels/optimized:libvec",
19+
"//executorch/extension/kernel_util:kernel_util",
20+
"//executorch/extension/parallel:thread_parallel",
21+
"//executorch/extension/threadpool:threadpool",
22+
],
23+
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
24+
visibility = [
25+
"//executorch/...",
26+
"//executorch/extension/llm/custom_ops/...",
27+
"@EXECUTORCH_CLIENTS",
28+
],
29+
# @lint-ignore BUCKLINT link_whole
30+
link_whole = True,
31+
force_static = True,
32+
)
3233

33-
runtime.cxx_library(
34-
name = "custom_ops_aot_lib",
35-
srcs = [
36-
"op_sdpa_aot.cpp",
37-
],
38-
visibility = [
39-
"//executorch/...",
40-
"@EXECUTORCH_CLIENTS",
41-
],
42-
external_deps = [
43-
"libtorch",
44-
],
45-
deps = [
46-
":custom_ops",
47-
"//executorch/extension/aten_util:aten_bridge",
48-
],
49-
)
34+
runtime.cxx_library(
35+
name = "custom_ops_aot_lib" + mkl_dep,
36+
srcs = [
37+
"op_sdpa_aot.cpp",
38+
],
39+
visibility = [
40+
"//executorch/...",
41+
"@EXECUTORCH_CLIENTS",
42+
],
43+
external_deps = [
44+
"libtorch",
45+
],
46+
deps = [
47+
":custom_ops" + mkl_dep,
48+
"//executorch/extension/aten_util:aten_bridge",
49+
],
50+
)
5051

5152
runtime.python_library(
5253
name = "custom_ops_aot_py",

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -392,17 +392,50 @@ def setUp(self):
392392
self.max_seq_len = 2048
393393
self.setup_caches()
394394

395+
def _scale_tensor(self, tensor, min_value, max_value, scale=True):
396+
normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
397+
398+
scaled_tensor = normalized_tensor * (max_value - min_value) + min_value
399+
400+
return scaled_tensor if scale else tensor
401+
395402
def _test_sdpa_common(
396-
self, n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len=1
403+
self,
404+
n_heads_kv,
405+
n_heads_q,
406+
head_dim,
407+
max_seq_len,
408+
seq_len,
409+
next_iter_seq_len=1,
410+
scale_tensors=False,
397411
):
412+
# Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
413+
tensor_scale_max = 20
414+
tensor_scale_min = -20
398415
self.n_heads_kv = n_heads_kv
399416
self.n_heads_q = n_heads_q
400417
self.head_dim = head_dim
401418
self.max_seq_len = max_seq_len
402419
self.setup_caches()
403-
q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
404-
k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
405-
v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
420+
q = self._scale_tensor(
421+
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
422+
tensor_scale_max,
423+
tensor_scale_min,
424+
scale_tensors,
425+
)
426+
k = self._scale_tensor(
427+
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
428+
tensor_scale_max,
429+
tensor_scale_min,
430+
scale_tensors,
431+
)
432+
v = self._scale_tensor(
433+
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
434+
tensor_scale_max,
435+
tensor_scale_min,
436+
scale_tensors,
437+
)
438+
406439
start_pos = 0
407440
attn_mask = self.mask[start_pos : start_pos + seq_len, :]
408441
attn_mask = attn_mask[:, : start_pos + seq_len]
@@ -412,11 +445,27 @@ def _test_sdpa_common(
412445
op_output = torch.ops.llama.sdpa_with_kv_cache(
413446
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
414447
)
415-
self.assertTrue(torch.allclose(ref_output, op_output))
448+
self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
449+
450+
q = self._scale_tensor(
451+
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
452+
tensor_scale_max,
453+
tensor_scale_min,
454+
scale_tensors,
455+
)
456+
k = self._scale_tensor(
457+
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
458+
tensor_scale_max,
459+
tensor_scale_min,
460+
scale_tensors,
461+
)
462+
v = self._scale_tensor(
463+
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
464+
tensor_scale_max,
465+
tensor_scale_min,
466+
scale_tensors,
467+
)
416468

417-
q = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
418-
k = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
419-
v = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
420469
start_pos = seq_len
421470
seq_len = q.size(1)
422471
attn_mask = self.mask[start_pos : start_pos + seq_len, :]
@@ -427,7 +476,7 @@ def _test_sdpa_common(
427476
op_output = torch.ops.llama.sdpa_with_kv_cache(
428477
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
429478
)
430-
self.assertTrue(torch.allclose(ref_output, op_output))
479+
self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
431480

432481

433482
class SDPATestForLargeSeqLength(SDPATestCommon):
@@ -438,7 +487,9 @@ def test_sdpa_with_cache_seq_len_130(self):
438487
head_dim = 128
439488
max_seq_len = 2048
440489
seq_len = 130
441-
self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
490+
self._test_sdpa_common(
491+
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, True
492+
)
442493

443494
def test_sdpa_with_cache_seq_len_small(self):
444495
n_heads_kv = 4
@@ -462,7 +513,9 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
462513
head_dim = 128
463514
max_seq_len = 2048
464515
seq_len = 130
465-
self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
516+
self._test_sdpa_common(
517+
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, True
518+
)
466519

467520
def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
468521
n_heads_kv = 16
@@ -483,7 +536,13 @@ def test_sdpa_with_cache_seq_len_130(self):
483536
seq_len = 130
484537
next_iter_seq_len = 17
485538
self._test_sdpa_common(
486-
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
539+
n_heads_kv,
540+
n_heads_q,
541+
head_dim,
542+
max_seq_len,
543+
seq_len,
544+
next_iter_seq_len,
545+
True,
487546
)
488547

489548
def test_sdpa_with_cache_seq_len_llava_example(self):
@@ -505,7 +564,13 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
505564
seq_len = 130
506565
next_iter_seq_len = 33
507566
self._test_sdpa_common(
508-
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
567+
n_heads_kv,
568+
n_heads_q,
569+
head_dim,
570+
max_seq_len,
571+
seq_len,
572+
next_iter_seq_len,
573+
True,
509574
)
510575

511576
def test_sdpa_with_cache_seq_len_llava_example_gqa(self):

kernels/optimized/lib_defs.bzl

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFORM_REGEX")
2+
load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
23
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
34

45
# Because vec exists as a collection of header files, compile and preprocessor
@@ -99,44 +100,64 @@ def define_libs():
99100
],
100101
)
101102

102-
runtime.cxx_library(
103-
name = "libblas",
104-
srcs = native.glob([
105-
"blas/**/*.cpp",
106-
]),
107-
exported_headers = native.glob([
108-
"blas/**/*.h",
109-
]),
110-
header_namespace = "executorch/kernels/optimized",
111-
visibility = [
112-
"//executorch/...",
113-
"@EXECUTORCH_CLIENTS",
114-
],
115-
fbandroid_platform_preprocessor_flags = [
116-
(
117-
"^android-arm64.*$",
118-
[
119-
"-DET_BUILD_WITH_BLAS",
120-
],
121-
),
122-
],
123-
fbandroid_platform_deps = [
124-
(
125-
"^android-arm64.*$",
126-
[
127-
"fbsource//third-party/openblas:openblas",
128-
],
129-
),
130-
],
131-
fbobjc_exported_preprocessor_flags = [
132-
"-DET_BUILD_WITH_BLAS",
133-
"-DET_BUILD_FOR_APPLE",
134-
],
135-
fbobjc_frameworks = [
136-
"Accelerate",
137-
],
138-
exported_deps = [
139-
"//executorch/kernels/optimized:libutils",
140-
"//executorch/runtime/core/exec_aten:lib",
103+
# OSS doesn't have ovr_config//os:linux-x86_64
104+
fb_native.config_setting(
105+
name = "linux-x86_64",
106+
constraint_values = [
107+
"ovr_config//os/constraints:linux",
108+
"ovr_config//cpu/constraints:x86_64",
141109
],
142110
)
111+
112+
for libblas_name, mkl_dep in [("libblas", "fbsource//third-party/mkl:mkl_lp64_omp"), ("libblas_mkl_noomp", "fbsource//third-party/mkl:mkl")]:
113+
runtime.cxx_library(
114+
name = libblas_name,
115+
srcs = native.glob([
116+
"blas/**/*.cpp",
117+
]),
118+
exported_headers = native.glob([
119+
"blas/**/*.h",
120+
]),
121+
header_namespace = "executorch/kernels/optimized",
122+
visibility = [
123+
"//executorch/...",
124+
"@EXECUTORCH_CLIENTS",
125+
],
126+
preprocessor_flags = select({
127+
":linux-x86_64": [
128+
"-DET_BUILD_WITH_BLAS",
129+
] if not runtime.is_oss else [],
130+
"DEFAULT": [],
131+
}),
132+
fbandroid_platform_preprocessor_flags = [
133+
(
134+
"^android-arm64.*$",
135+
[
136+
"-DET_BUILD_WITH_BLAS",
137+
],
138+
),
139+
],
140+
fbandroid_platform_deps = [
141+
(
142+
"^android-arm64.*$",
143+
[
144+
"fbsource//third-party/openblas:openblas",
145+
],
146+
),
147+
],
148+
fbobjc_exported_preprocessor_flags = [
149+
"-DET_BUILD_WITH_BLAS",
150+
"-DET_BUILD_FOR_APPLE",
151+
],
152+
fbobjc_frameworks = [
153+
"Accelerate",
154+
],
155+
deps = select({
156+
":linux-x86_64": [mkl_dep] if not runtime.is_oss else [],
157+
"DEFAULT": [],
158+
}),
159+
exported_deps = [
160+
"//executorch/kernels/optimized:libutils",
161+
"//executorch/runtime/core/exec_aten:lib",
162+
],
163+
)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under both the MIT license found in the
4+
# LICENSE-MIT file in the root directory of this source tree and the Apache
5+
# License, Version 2.0 found in the LICENSE-APACHE file in the root directory
6+
# of this source tree.
7+
8+
fb_native = struct(
9+
config_setting = native.config_setting,
10+
)

0 commit comments

Comments
 (0)