Skip to content

Commit 8b4c0ed

Browse files
committed
Update base for Update on "Xnnpack test for program-data separation"
Add xnnpack test for program-data separation Differential Revision: [D73794695](https://our.internmc.facebook.com/intern/diff/D73794695/) [ghstack-poisoned]
2 parents bde492b + 6932baf commit 8b4c0ed

File tree

11 files changed

+142
-10
lines changed

11 files changed

+142
-10
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2020
from .convert_to_clamp import ConvertToClampPass # noqa
2121
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
22+
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2223
from .decompose_div_pass import DecomposeDivPass # noqa
2324
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2425
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ConvertSqueezesToViewPass,
2525
ConvertToClampPass,
2626
DecomposeBatchNormPass,
27+
DecomposeCosineSimilarityPass,
2728
DecomposeDivPass,
2829
DecomposeGeluPass,
2930
DecomposeLayerNormPass,
@@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
205206
self.add_pass(DecomposeVarPass())
206207
self.add_pass(DecomposeMeanDimPass())
207208
self.add_pass(DecomposeNotEqualPass())
209+
self.add_pass(DecomposeCosineSimilarityPass())
208210
self.add_pass(DecomposeDivPass())
209211
self.add_pass(DecomposeLeakyReLUPass())
210212
self.add_pass(DecomposeSqrtPass())
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
10+
11+
12+
class DecomposeCosineSimilarityPass(ExportPass):
    """
    Decomposition of aten.cosine_similarity into the ops emitted below:

      dot   = sum(mul(x1, x2), dims, keepdim=False)
      norm  = pow(sum(mul(x, x), dims, keepdim=False), 0.5)
      eps   = full_like(norm1, eps_scalar)
      n1c   = maximum(norm1, eps)
      n2c   = maximum(norm2, eps)
      denom = mul(n1c, n2c)
      out   = div(dot, denom)
    """

    def _l2_norm(self, x, dims, meta):
        """Emit pow(sum(mul(x, x), dims, keepdim=False), 0.5) and return the result."""
        squared = super().call_operator(torch.ops.aten.mul.Tensor, (x, x), {}, meta)
        summed = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (squared, dims, False), {}, meta
        )
        return super().call_operator(
            torch.ops.aten.pow.Tensor_Scalar, (summed, 0.5), {}, meta
        )

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a cosine_similarity call with its decomposition.

        Any op that is not listed in ``torch_cosine_similarity`` is forwarded
        unchanged to the parent implementation.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call; ``args[0]`` and ``args[1]``
                are the two inputs, optionally followed by ``dim`` and ``eps``.
            kwargs: Keyword arguments of the call; may carry ``dim`` and ``eps``.
            meta: Node metadata, passed through to every emitted operator.

        Returns:
            The value produced by the final ``div`` of the decomposition, or
            the parent's result for ops that are not decomposed.
        """
        if op not in torch_cosine_similarity:
            return super().call_operator(op, args, kwargs, meta)

        x1, x2 = args[0], args[1]
        # dim and eps are not keyword-only in the aten schema, so they can
        # arrive either positionally or as kwargs. Honour both instead of
        # silently falling back to the defaults when they are positional.
        dim = args[2] if len(args) > 2 else kwargs.get("dim", 1)
        eps = args[3] if len(args) > 3 else kwargs.get("eps", 1e-8)
        dims = [dim] if isinstance(dim, int) else list(dim)

        # 1) dot
        prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
        dot = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
        )

        # 2) norm1 / norm2 = pow(sum(x*x), 0.5)
        norm1 = self._l2_norm(x1, dims, meta)
        norm2 = self._l2_norm(x2, dims, meta)

        # 3) eps scalar - we need to broadcast it ourselves as TOSA doesn't do
        #    this for scalars
        eps_t = super().call_operator(
            torch.ops.aten.full_like.default, (norm1, eps), {}, meta
        )

        # 4) clamp to avoid zero division
        n1c = super().call_operator(
            torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
        )
        n2c = super().call_operator(
            torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
        )

        # 5) denom and divide
        denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
        out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)

        return out

backends/arm/test/models/test_nn_functional.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def test_nn_functional_MI(test_data):
106106

107107
x_fails = {
108108
"normalize": "MLETORCH-852: Support aten.index_put.default",
109-
"cosine_similarity": "MLETORCH-854: Support aten.linalg_vector_norm.default",
110109
"unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
111110
"fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
112111
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes.decompose_cosine_similarity_pass import (
11+
DecomposeCosineSimilarityPass,
12+
)
13+
from executorch.backends.arm.test import common
14+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
15+
16+
input_t = Tuple[torch.Tensor, torch.Tensor]
17+
18+
19+
class CosineSimilarityModel(torch.nn.Module):
    """Small module that applies torch.cosine_similarity with dim=1, eps=1e-6."""

    def get_inputs(self) -> input_t:
        """Return a pair of random tensors, each of shape (2, 3, 4)."""
        first = torch.rand(2, 3, 4)
        second = torch.rand(2, 3, 4)
        return (first, second)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Compute the cosine similarity of x1 and x2 along dimension 1."""
        similarity = torch.cosine_similarity(x1, x2, dim=1, eps=1e-6)
        return similarity
25+
26+
27+
modules = {"cosine_basic": CosineSimilarityModel()}
28+
29+
30+
@common.parametrize("module", modules)
def test_decompose_cosine_similarity_tosa_BI(module):
    """Run DecomposeCosineSimilarityPass on the module and check op counts."""
    # Expected number of each edge-dialect op once the pass has run.
    expected_op_counts = {
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 5,
        "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 3,
        "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2,
        "executorch_exir_dialects_edge__ops_aten_full_like_default": 1,
        "executorch_exir_dialects_edge__ops_aten_maximum_default": 2,
        "executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1,
    }

    example_inputs = module.get_inputs()

    # Only the post-pass op counts are checked; the other checks are disabled.
    pipeline_options = {
        "tosa_version": "TOSA-0.80+BI",
        "ops_before_pass": None,
        "ops_not_before_pass": None,
        "ops_after_pass": expected_op_counts,
        "ops_not_after_pass": None,
        "pass_list": [DecomposeCosineSimilarityPass],
    }
    pipeline = PassPipeline[input_t](module, example_inputs, **pipeline_options)
    pipeline.run()

examples/models/llama/runner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ else()
5353
endif()
5454

5555
set(llama_runner_deps executorch_core extension_data_loader extension_module
56-
extension_tensor
56+
extension_tensor extension_flat_tensor
5757
)
5858

5959
target_link_libraries(llama_runner PUBLIC ${llama_runner_deps})

examples/models/llava/runner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ add_subdirectory(
4141
add_library(llava_runner STATIC ${_llava_runner__srcs})
4242

4343
set(llava_runner_deps executorch_core extension_data_loader extension_llm_runner
44-
extension_module extension_tensor
44+
extension_module extension_tensor extension_flat_tensor
4545
)
4646

4747
target_link_libraries(llava_runner PUBLIC ${llava_runner_deps})

extension/android/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ list(
7272
link_libraries
7373
executorch
7474
extension_data_loader
75+
extension_flat_tensor
7576
extension_module
7677
extension_runner_util
7778
extension_tensor

extension/llm/runner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ target_include_directories(
4343

4444
add_library(extension_llm_runner STATIC ${_extension_llm_runner__srcs})
4545

46-
set(runner_deps executorch extension_module extension_tensor)
46+
set(runner_deps executorch_core extension_module extension_tensor)
4747

4848
target_link_libraries(extension_llm_runner PUBLIC ${runner_deps})
4949

tools/cmake/cmake_deps.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ filters = [
196196
".cpp$",
197197
]
198198
deps = [
199-
"extension_flat_tensor_schema",
200199
"executorch_core",
201-
"executorch",
202200
]
203201

204202
[targets.extension_module]
@@ -209,9 +207,9 @@ filters = [
209207
".cpp$",
210208
]
211209
deps = [
212-
"executorch",
213210
"executorch_core",
214211
"extension_data_loader",
212+
"extension_flat_tensor",
215213
]
216214

217215
[targets.extension_runner_util]
@@ -233,9 +231,12 @@ filters = [
233231
".cpp$",
234232
]
235233
deps = [
236-
"executorch",
237234
"executorch_core",
235+
"extension_data_loader",
236+
"extension_flat_tensor",
238237
"extension_module",
238+
"extension_data_loader",
239+
"extension_flat_tensor",
239240
"extension_runner_util",
240241
"extension_tensor",
241242
]
@@ -248,7 +249,6 @@ filters = [
248249
".cpp$",
249250
]
250251
deps = [
251-
"executorch",
252252
"executorch_core",
253253
]
254254

@@ -260,7 +260,6 @@ filters = [
260260
".cpp$",
261261
]
262262
deps = [
263-
"executorch",
264263
"executorch_core",
265264
]
266265

@@ -452,7 +451,9 @@ deps = [
452451
"executorch",
453452
"executorch_core",
454453
"extension_data_loader",
454+
"extension_flat_tensor",
455455
"extension_module",
456+
"extension_tensor",
456457
"extension_threadpool",
457458
"optimized_cpublas",
458459
"portable_kernels",

tools/cmake/executorch-config.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ set(lib_list
6666
etdump
6767
bundled_program
6868
extension_data_loader
69+
extension_flat_tensor
6970
${FLATCCRT_LIB}
7071
coreml_util
7172
coreml_inmemoryfs

0 commit comments

Comments
 (0)