Skip to content

Commit ec6ec3c

Browse files
committed
Merge remote-tracking branch 'origin/main' into use-executorch-core
2 parents ff56f04 + 1984c5f commit ec6ec3c

File tree

8 files changed

+140
-7
lines changed

8 files changed

+140
-7
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2020
from .convert_to_clamp import ConvertToClampPass # noqa
2121
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
22+
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2223
from .decompose_div_pass import DecomposeDivPass # noqa
2324
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2425
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ConvertSqueezesToViewPass,
2525
ConvertToClampPass,
2626
DecomposeBatchNormPass,
27+
DecomposeCosineSimilarityPass,
2728
DecomposeDivPass,
2829
DecomposeGeluPass,
2930
DecomposeLayerNormPass,
@@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
205206
self.add_pass(DecomposeVarPass())
206207
self.add_pass(DecomposeMeanDimPass())
207208
self.add_pass(DecomposeNotEqualPass())
209+
self.add_pass(DecomposeCosineSimilarityPass())
208210
self.add_pass(DecomposeDivPass())
209211
self.add_pass(DecomposeLeakyReLUPass())
210212
self.add_pass(DecomposeSqrtPass())
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
10+
11+
12+
class DecomposeCosineSimilarityPass(ExportPass):
13+
"""
14+
Decomposition of aten.cosine_similarity:
15+
16+
dot = sum(mul(x1, x2), dims, keepdim=False)
17+
norm = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
18+
eps = full( (), eps_scalar )
19+
n1c = max(norm1, eps)
20+
n2c = max(norm2, eps)
21+
denom = mul(n1c, n2c)
22+
out = div(dot, denom)
23+
"""
24+
25+
def call_operator(self, op, args, kwargs, meta):
26+
if op not in torch_cosine_similarity:
27+
return super().call_operator(op, args, kwargs, meta)
28+
29+
x1, x2 = args[0], args[1]
30+
dim = kwargs.get("dim", 1)
31+
eps = kwargs.get("eps", 1e-8)
32+
dims = [dim] if isinstance(dim, int) else list(dim)
33+
34+
# 1) dot
35+
prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
36+
dot = super().call_operator(
37+
torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
38+
)
39+
40+
# 2a) norm1 = pow(sum(x1*x1), 0.5)
41+
x1_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x1), {}, meta)
42+
s1 = super().call_operator(
43+
torch.ops.aten.sum.dim_IntList, (x1_sq, dims, False), {}, meta
44+
)
45+
norm1 = super().call_operator(
46+
torch.ops.aten.pow.Tensor_Scalar, (s1, 0.5), {}, meta
47+
)
48+
49+
# 2b) norm2 = pow(sum(x2*x2), 0.5)
50+
x2_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x2, x2), {}, meta)
51+
s2 = super().call_operator(
52+
torch.ops.aten.sum.dim_IntList, (x2_sq, dims, False), {}, meta
53+
)
54+
norm2 = super().call_operator(
55+
torch.ops.aten.pow.Tensor_Scalar, (s2, 0.5), {}, meta
56+
)
57+
58+
# 3) eps scalar - we need to broadcast ourselves as TOSA dont do this for scalar
59+
eps_t = super().call_operator(
60+
torch.ops.aten.full_like.default, (norm1, eps), {}, meta
61+
)
62+
63+
# 4) clamp to avoid zero division
64+
n1c = super().call_operator(
65+
torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
66+
)
67+
n2c = super().call_operator(
68+
torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
69+
)
70+
71+
# 5) denom and divide
72+
denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
73+
out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)
74+
75+
return out

backends/arm/test/models/test_nn_functional.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def test_nn_functional_MI(test_data):
106106

107107
x_fails = {
108108
"normalize": "MLETORCH-852: Support aten.index_put.default",
109-
"cosine_similarity": "MLETORCH-854: Support aten.linalg_vector_norm.default",
110109
"unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
111110
"fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
112111
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes.decompose_cosine_similarity_pass import (
11+
DecomposeCosineSimilarityPass,
12+
)
13+
from executorch.backends.arm.test import common
14+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
15+
16+
input_t = Tuple[torch.Tensor, torch.Tensor]
17+
18+
19+
class CosineSimilarityModel(torch.nn.Module):
20+
def get_inputs(self) -> input_t:
21+
return (torch.rand(2, 3, 4), torch.rand(2, 3, 4))
22+
23+
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
24+
return torch.cosine_similarity(x1, x2, dim=1, eps=1e-6)
25+
26+
27+
modules = {"cosine_basic": CosineSimilarityModel()}
28+
29+
30+
@common.parametrize("module", modules)
31+
def test_decompose_cosine_similarity_tosa_BI(module):
32+
33+
ops_after_pass = {
34+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 5,
35+
"executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 3,
36+
"executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2,
37+
"executorch_exir_dialects_edge__ops_aten_full_like_default": 1,
38+
"executorch_exir_dialects_edge__ops_aten_maximum_default": 2,
39+
"executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1,
40+
}
41+
42+
pipeline = PassPipeline[input_t](
43+
module,
44+
module.get_inputs(),
45+
tosa_version="TOSA-0.80+BI",
46+
ops_before_pass=None,
47+
ops_not_before_pass=None,
48+
ops_after_pass=ops_after_pass,
49+
ops_not_after_pass=None,
50+
pass_list=[DecomposeCosineSimilarityPass],
51+
)
52+
pipeline.run()

backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ void main() {
8686
const int in_l = out_l * stride - padding;
8787
VEC4_T sum = VEC4_T(0);
8888

89+
const int out_c_packed_index = out_c >> 2;
90+
const int out_c_packed_lane = out_c & 0x3;
91+
8992
for (int in_c = c_start; in_c < c_end; ++in_c) {
9093
// "k" tracks the kernel's index for our input-kernel computation.
9194
// It reads out-of-bound zeros, but trying to avoid them complicates
@@ -103,16 +106,16 @@ void main() {
103106
// It is possible to further reduce the memory footprint by swapping the
104107
// dimensions, using x extent for out_channel, and y for kernel.
105108
for (int k = 0; k < kernel_size; k++) {
106-
const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
109+
const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c_packed_index);
107110
const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
108-
VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
111+
VEC4_T weight = VEC4_T(weight_texel[out_c_packed_lane]);
109112

110113
const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map);
111114
sum = fma(weight, load_texel(t_in, in_pos), sum);
112115
}
113116
}
114117

115-
const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
118+
const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c_packed_index, 0, 0), bias_axis_map);
116119
const ivec3 out_lpos = ivec3(out_l, out_c, N);
117-
write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
120+
write_texel_lpos(t_out, out_lpos, op(sum + bias[out_c_packed_lane], out_min, out_max), out_axis_map);
118121
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ void add_conv1d_node(
483483
weight,
484484
/*transposed = */ false,
485485
/*storage_type = */ utils::kTexture3D,
486-
/*memory_layout = */ utils::kChannelsPacked);
486+
/*memory_layout = */ utils::kWidthPacked);
487487

488488
float out_min_val = 0.0f;
489489
float out_max_val = 0.0f;

tools/cmake/cmake_deps.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ filters = [
197197
]
198198
deps = [
199199
"executorch_core",
200-
"extension_flat_tensor_schema",
201200
]
202201

203202
[targets.extension_module]
@@ -236,6 +235,8 @@ deps = [
236235
"extension_data_loader",
237236
"extension_flat_tensor",
238237
"extension_module",
238+
"extension_data_loader",
239+
"extension_flat_tensor",
239240
"extension_runner_util",
240241
"extension_tensor",
241242
]

0 commit comments

Comments
 (0)