Skip to content

Commit e1ce954

Browse files
authored
Merge branch 'main' into gh/trivedivivek/38/orig
2 parents 74ab802 + 6cb9037 commit e1ce954

File tree

22 files changed

+275
-208
lines changed

22 files changed

+275
-208
lines changed

.ci/docker/conda-env-ci.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
cmake=3.22.1
22
ninja=1.10.2
3+
libuv
4+
pkg-config

.ci/scripts/setup-macos.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,9 @@ if [[ -z "${GITHUB_RUNNER:-}" ]]; then
131131
fi
132132

133133
print_cmake_info
134-
install_executorch
134+
install_pytorch_and_domains
135+
# We build PyTorch from source here instead of using nightly. This allows CI to test against
136+
# the pinned commit from PyTorch
137+
install_executorch "use-pt-pinned-commit"
135138
build_executorch_runner "${BUILD_TOOL}"
139+
do_not_use_nightly_on_ci

.ci/scripts/utils.sh

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,42 @@ install_pip_dependencies() {
4040
popd || return
4141
}
4242

43+
install_domains() {
44+
echo "Install torchvision and torchaudio"
45+
pip install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}"
46+
pip install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}"
47+
}
48+
49+
install_pytorch_and_domains() {
50+
pushd .ci/docker || return
51+
TORCH_VERSION=$(cat ci_commit_pins/pytorch.txt)
52+
popd || return
53+
54+
git clone https://github.com/pytorch/pytorch.git
55+
56+
# Fetch the target commit
57+
pushd pytorch || return
58+
git checkout "${TORCH_VERSION}"
59+
git submodule update --init --recursive
60+
61+
export USE_DISTRIBUTED=1
62+
# Then build and install PyTorch
63+
python setup.py bdist_wheel
64+
pip install "$(echo dist/*.whl)"
65+
66+
# Grab the pinned audio and vision commits from PyTorch
67+
TORCHAUDIO_VERSION=$(cat .github/ci_commit_pins/audio.txt)
68+
export TORCHAUDIO_VERSION
69+
TORCHVISION_VERSION=$(cat .github/ci_commit_pins/vision.txt)
70+
export TORCHVISION_VERSION
71+
72+
install_domains
73+
74+
popd || return
75+
# Print sccache stats for debugging
76+
sccache --show-stats || true
77+
}
78+
4379
install_flatc_from_source() {
4480
# NB: This function could be used to install flatbuffer from source
4581
pushd third-party/flatbuffers || return

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -36,7 +36,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
3636
itertools.chain.from_iterable(matmul_partitions.values())
3737
)
3838
matmul_targets = {
39-
exir_ops.edge.aten.mm.default,
4039
exir_ops.edge.aten.bmm.default,
4140
}
4241
for partition in matmul_partitions:

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
@@ -45,6 +45,7 @@
4545
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
4646
ConvertMeanDimToAveragePool,
4747
)
48+
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
4849
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
4950
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
5051
ScalarsToAttributePass,
@@ -79,6 +80,7 @@ def transform_to_backend_pipeline(
7980
self.add_pass(ConvertMeanDimToAveragePool())
8081
self.add_pass(DecomposeMeanDimPass())
8182
self.add_pass(ConvertSplitToSlicePass())
83+
self.add_pass(ConvertMmToBmmPass())
8284
# TODO MLETORCH-558
8385
self.add_pass(AnnotateDecomposedMatmulPass())
8486
self.add_pass(QuantizeFullArgument())
@@ -99,7 +101,6 @@ def transform_to_backend_pipeline(
99101
exir_ops.edge.aten.hardtanh.default,
100102
exir_ops.edge.aten.log.default,
101103
exir_ops.edge.aten.max_pool2d.default,
102-
exir_ops.edge.aten.mm.default,
103104
exir_ops.edge.aten.mul.Tensor,
104105
exir_ops.edge.aten.permute_copy.default,
105106
exir_ops.edge.aten.reciprocal.default,

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from typing import cast
1010

11-
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
1211
from executorch.exir.dialects._ops import ops as exir_ops
1312
from executorch.exir.pass_base import ExportPass
1413

@@ -25,14 +24,14 @@ def call_operator(self, op, args, kwargs, meta):
2524
if op != self.expand_copy:
2625
return super().call_operator(op, args, kwargs, meta)
2726

28-
_, shape, _ = extract_tensor_meta(meta.data)
27+
input_shape = args[0].data.shape
2928
multiples = cast(list[int], args[1])
3029
expanded_rank = len(multiples)
3130

32-
# Expanded shape is 'shape' front-padded with ones.
33-
padding = expanded_rank - len(shape)
31+
# Expanded shape is 'input_shape' front-padded with ones.
32+
padding = expanded_rank - len(input_shape)
3433
extended_shape = [
35-
shape[i] if i >= 0 else 1 for i in range(-padding, len(shape))
34+
input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
3635
]
3736

3837
# To convert expand arg to repeat arg, non-repeated dims should have

backends/arm/_passes/decompose_var_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta):
8383
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
8484
full = super().call_operator(
8585
full_op,
86-
([1 for _ in shape], 1 / max(0, N - correction)),
86+
([], 1 / max(0, N - correction)),
8787
{"dtype": dtype},
8888
meta,
8989
)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
9090
continue
9191

9292
# Calculate max rank of all inputs to node
93-
max_rank = 1
93+
max_rank = 0
9494
for arg in node.args:
9595
if isinstance(arg, Node):
9696
shape = get_first_fake_tensor(arg).shape
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.backends.arm._passes.arm_pass_utils import (
9+
create_node,
10+
get_first_fake_tensor,
11+
insert_q_dq_pair,
12+
)
13+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
from torch.fx import Node
17+
18+
19+
class ConvertMmToBmmPass(ExportPass):
20+
"""
21+
This pass converts a MM node to a BMM one and turns input and output tensors
22+
from rank 2 to rank 3. The TOSA specification requires rank 3. The graph is
23+
modified to do the following:
24+
1) Unsqueeze input tensors to rank 3.
25+
2) Convert MM node to BMM.
26+
3) Squeeze output tensor to rank 2.
27+
"""
28+
29+
def call(self, graph_module: torch.fx.GraphModule):
30+
modified_graph = False
31+
graph = graph_module.graph
32+
node_list = graph.find_nodes(
33+
op="call_function", target=exir_ops.edge.aten.mm.default
34+
)
35+
for node in node_list:
36+
# Unsqueeze input tensors to rank 3
37+
for input_node in node.args:
38+
if not isinstance(input_node, Node):
39+
continue
40+
41+
shape = get_first_fake_tensor(input_node).shape
42+
rank = len(shape)
43+
if rank != 2:
44+
raise RuntimeError(f"Input tensor has rank {rank}, must be 2")
45+
46+
with graph.inserting_before(node):
47+
unsqueeze_before = create_node(
48+
graph, exir_ops.edge.aten.unsqueeze_copy.default
49+
)
50+
unsqueeze_before.args = (
51+
input_node, # Input is node's original input
52+
0,
53+
)
54+
node.replace_input_with(input_node, unsqueeze_before)
55+
56+
# If Quantized we must insert unsqueeze --> q --> dq --> node
57+
if input_node.target == dq_op:
58+
q_params = input_node.args[1:]
59+
insert_q_dq_pair(graph, unsqueeze_before, q_params)
60+
61+
# Replace mm node with bmm
62+
with graph.inserting_before(node):
63+
bmm_node = create_node(
64+
graph,
65+
exir_ops.edge.aten.bmm.default,
66+
)
67+
bmm_node.args = node.args
68+
node.replace_all_uses_with(bmm_node)
69+
graph.erase_node(node)
70+
71+
# Unsqueeze output tensor to rank 3
72+
with graph.inserting_after(bmm_node):
73+
squeeze_after = create_node(
74+
graph,
75+
exir_ops.edge.aten.squeeze_copy.dims,
76+
)
77+
squeeze_after.args = (
78+
bmm_node,
79+
[0],
80+
)
81+
original_users = [
82+
user for user in bmm_node.users if user != squeeze_after
83+
]
84+
for user in original_users:
85+
user.replace_input_with(bmm_node, squeeze_after)
86+
87+
# If quantized, insert mm --> q --> dq --> squeeze
88+
if all(original_user.target == q_op for original_user in original_users):
89+
q_params = original_users[0].args[1:]
90+
insert_q_dq_pair(graph, bmm_node, q_params)
91+
92+
modified_graph = True
93+
94+
if modified_graph:
95+
graph_module.recompile()
96+
graph_module = super().call(graph_module).graph_module
97+
98+
return PassResult(graph_module, modified_graph)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -22,7 +22,6 @@
2222
op_max,
2323
op_max_pool2d,
2424
op_min,
25-
op_mm,
2625
op_mul,
2726
op_permute,
2827
op_quant,

0 commit comments

Comments
 (0)