Commit be44fb4

Update base for Update on "[Executorch] Refactor op_mul's broadcasting utils"
Summary: Refactoring the broadcast handling utils that were added for op_mul. This is in preparation for using these utils to handle broadcast for other ops such as add, sub, and div. Plus, remove a redundant test.

Test Plan: optimized_kernels_test in CI

Reviewers:

Subscribers:

Tasks:

Tags:

cc larryliu0820 manuelcandales

[ghstack-poisoned]
2 parents: ed79e8c + 78752a0

189 files changed: +5186 additions, -1054 deletions


.ci/docker/common/install_base.sh

Lines changed: 5 additions & 0 deletions
@@ -26,6 +26,11 @@ install_ubuntu() {
     libssl-dev \
     zip

+  # These libraries are needed by TorchVision
+  apt-get install -y --no-install-recommends \
+    libjpeg-dev \
+    libpng-dev
+
   # Cleanup package manager
   apt-get autoclean && apt-get clean
   rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

.ci/docker/common/install_conda.sh

Lines changed: 10 additions & 2 deletions
@@ -31,8 +31,16 @@ install_miniconda() {

 install_python() {
   pushd /opt/conda
-  # Install the correct Python version
+  # Install the selected Python version for CI jobs
   as_ci_user conda create -n "py_${PYTHON_VERSION}" -y --file /opt/conda/conda-env-ci.txt python="${PYTHON_VERSION}"
+
+  # From https://github.com/pytorch/pytorch/blob/main/.ci/docker/common/install_conda.sh
+  if [[ $(uname -m) == "aarch64" ]]; then
+    conda_install "openblas==0.3.28=*openmp*"
+  else
+    conda_install mkl=2022.1.0 mkl-include=2022.1.0
+  fi
+
   popd
 }

@@ -53,7 +61,7 @@ fix_conda_ubuntu_libstdcxx() {
   # PyTorch sev: https://github.com/pytorch/pytorch/issues/105248
   # Ref: https://github.com/pytorch/pytorch/blob/main/.ci/docker/common/install_conda.sh
   if grep -e "2[02].04." /etc/issue >/dev/null; then
-    rm "/opt/conda/envs/py_${PYTHON_VERSION}/lib/libstdc++.so.6"
+    rm /opt/conda/envs/py_${PYTHON_VERSION}/lib/libstdc++.so*
   fi
 }

.ci/scripts/gather_benchmark_configs.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def set_output(name: str, val: Any) -> None:
     try:
         with open(github_output, "a") as env:
             env.write(f"{name}={val}\n")
-    except PermissionError:
+    except (PermissionError, FileNotFoundError):
         # Fall back to printing in case of permission error in unit tests
         print(f"::set-output name={name}::{val}")
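As context for the broadened exception handling above, here is a minimal, self-contained sketch of the pattern. It assumes the output path comes from the standard `GITHUB_OUTPUT` environment variable, which is an assumption made for illustration rather than a description of the actual script.

```python
import os
from typing import Any


def set_output(name: str, val: Any) -> None:
    # Assumption for this sketch: the target file path is taken from the
    # GITHUB_OUTPUT environment variable, as on GitHub Actions runners.
    github_output = os.environ.get("GITHUB_OUTPUT", "")
    try:
        with open(github_output, "a") as env:
            env.write(f"{name}={val}\n")
    except (PermissionError, FileNotFoundError):
        # Outside CI (or in unit tests) the file may be missing or not
        # writable, so fall back to the legacy workflow-command syntax.
        print(f"::set-output name={name}::{val}")


set_output("benchmark_configs", '{"models": []}')
```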

.github/workflows/lint.yml

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,7 @@ jobs:
       # The generic Linux job chooses to use base env, not the one setup by the image
       CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
       conda activate "${CONDA_ENV}"
-
+
       # For mypy linting, we need to first install executorch first so that
       # it builds the python package information.
       BUILD_TOOL="cmake"

@@ -74,6 +74,7 @@ jobs:
     docker-image: executorch-ubuntu-22.04-linter
     fetch-depth: 0
     ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+    timeout: 90
     script: |
       FILES_NEEDS_FORMAT=$(/opt/google-java-format -n extension/android/src/main/java/org/pytorch/executorch/*.java \
         examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/*.java \

.github/workflows/pull.yml

Lines changed: 2 additions & 2 deletions
@@ -212,7 +212,7 @@ jobs:
     docker-image: executorch-ubuntu-22.04-clang12
     submodules: 'true'
     ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-    timeout: 180
+    timeout: 90
     script: |
       # The generic Linux job chooses to use base env, not the one setup by the image
       CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")

@@ -526,7 +526,7 @@ jobs:
     docker-image: executorch-ubuntu-22.04-clang12
     submodules: 'true'
     ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-    timeout: 180
+    timeout: 90
     script: |
       # The generic Linux job chooses to use base env, not the one setup by the image
       CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -67,3 +67,6 @@
 [submodule "backends/cadence/utils/FACTO"]
 	path = backends/cadence/utils/FACTO
 	url = https://github.com/pytorch-labs/FACTO.git
+[submodule "third-party/pocketfft"]
+	path = third-party/pocketfft
+	url = https://github.com/mreineck/pocketfft

CMakeLists.txt

Lines changed: 12 additions & 0 deletions
@@ -182,6 +182,10 @@ option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension"
        OFF
 )

+option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR "Build the Flat Tensor extension"
+       OFF
+)
+
 option(EXECUTORCH_BUILD_EXTENSION_MODULE "Build the Module extension" OFF)

 option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL "Build the Runner Util extension"

@@ -240,6 +244,9 @@ cmake_dependent_option(
   "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
 )

+if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
+  set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
+endif()

 if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
   set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)

@@ -694,6 +701,11 @@ if(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/data_loader)
 endif()

+if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor/serialize)
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_MODULE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
 endif()

CONTRIBUTING.md

Lines changed: 8 additions & 2 deletions
@@ -215,6 +215,14 @@ must work with threading**

 ## Testing

+### Running Tests Locally
+
+CI is run automatically on all pull requests. However, if you want to run tests locally, here are some example commands (not exhaustive):
+
+- Running `sh test/build_size_test.sh` will compile the C++ runtime along with portable kernels.
+- The `test/run_oss_cpp_tests.sh` script will build and run C++ tests locally.
+- Running `pytest` from the root directory will run Python tests locally.
+
 ### Writing Tests
 To help keep code quality high, ExecuTorch uses a combination of unit tests and
 end-to-end (e2e) tests. If you add a new feature or fix a bug, please add tests

@@ -229,8 +237,6 @@ If it's not clear how to add a test for your PR, take a look at the blame for
 the code you're modifying and find an author who has more context. Ask them
 for their help in the PR comments.

-The `test/run_oss_cpp_tests.sh` script will build and run C++ tests locally.
-
 ### Continuous Integration
 See https://hud.pytorch.org/hud/pytorch/executorch/main for the current state of
 the CI (continuous integration) jobs. If `main` is broken, consider rebasing

README.md

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 <div align="center">
   <a href="https://github.com/pytorch/executorch/graphs/contributors"><img src="https://img.shields.io/github/contributors/pytorch/executorch?style=for-the-badge&color=blue" alt="Contributors"></a>
   <a href="https://github.com/pytorch/executorch/stargazers"><img src="https://img.shields.io/github/stars/pytorch/executorch?style=for-the-badge&color=blue" alt="Stargazers"></a>
-  <a href="https://discord.gg/MeacgB7A"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
+  <a href="https://discord.gg/Dh43CKSAdc"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
   <a href="https://pytorch.org/executorch/stable/index.html"><img src="https://img.shields.io/badge/Documentation-000?logo=googledocs&logoColor=FFE165&style=for-the-badge" alt="Check out the documentation"></a>
   <hr>
 </div>

@@ -55,11 +55,11 @@ To get started you can:
 ## Feedback and Engagement

 We welcome any feedback, suggestions, and bug reports from the community to help
-us improve our technology. Check out the [Discussion Board](https://github.com/pytorch/executorch/discussions) or chat real time with us on [Discord](https://discord.gg/MeacgB7A)
+us improve our technology. Check out the [Discussion Board](https://github.com/pytorch/executorch/discussions) or chat real time with us on [Discord](https://discord.gg/Dh43CKSAdc)

 ## Contributing

-We welcome contributions. To get started review the [guidelines](CONTRIBUTING.md) and chat with us on [Discord](https://discord.gg/MeacgB7A)
+We welcome contributions. To get started review the [guidelines](CONTRIBUTING.md) and chat with us on [Discord](https://discord.gg/Dh43CKSAdc)


 ## Directory Structure

backends/arm/_passes/_debug_passes.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.devtools.visualization.visualization_utils import visualize_graph
+from executorch.exir import ExportedProgram
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class VisualizePass(ExportPass):
+    """
+    This pass visualizes the graph at the point of insertion in the pass manager
+    """
+
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        super().__init__()
+        self.exported_program = exported_program
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        visualize_graph(graph_module, self.exported_program)
+        return PassResult(graph_module, False)
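As a rough illustration of how this pass could be exercised, here is a hedged sketch that runs it directly on the graph module of an exported toy model. It assumes executorch and its visualization tooling are installed; `TinyModule` and its inputs are made up for the example and are not part of this commit.

```python
# Hedged sketch: apply VisualizePass directly to an exported program's graph.
import torch
from executorch.backends.arm._passes._debug_passes import VisualizePass


class TinyModule(torch.nn.Module):  # hypothetical example model
    def forward(self, x):
        return torch.nn.functional.relu(x + 1)


exported = torch.export.export(TinyModule(), (torch.randn(2, 2),))
result = VisualizePass(exported)(exported.graph_module)
print(result.modified)  # False: the pass only visualizes, it does not change the graph
```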

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 1 deletion
@@ -123,6 +123,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeBatchNormPass())

@@ -132,7 +133,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxesPass())
-        self.add_pass(FuseBatchnorm2DPass(exported_program))

         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())

backends/arm/_passes/decompose_select.py

Lines changed: 1 addition & 2 deletions
@@ -37,14 +37,13 @@ def call(self, graph_module: torch.fx.GraphModule):
             rank = len(input_node.meta["val"].size())
             dim = dim % rank if dim < 0 else dim
             index = index % rank if index < 0 else index
-            dim_list = list(range(rank))

             with graph_module.graph.inserting_before(node):
                 slice_node = create_node(
                     graph_module.graph, slice_op, (input_node, dim, index, index + 1)
                 )
                 squeeze_node = create_node(
-                    graph_module.graph, squeeze_op, (slice_node, dim_list)
+                    graph_module.graph, squeeze_op, (slice_node, [dim])
                 )

                 node.replace_all_uses_with(squeeze_node)
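For intuition about why squeezing only `[dim]` is enough here, a small hedged sketch in plain torch (not the FX pass itself) of the equivalence this decomposition relies on: selecting an index along a dimension is the same as slicing a single element along that dimension and then squeezing just that dimension.

```python
import torch

x = torch.randn(2, 3, 4)
dim, index = 1, 2

# aten.select keeps every dimension except `dim`, which is dropped.
selected = torch.select(x, dim, index)

# Equivalent decomposition: slice out a single element along `dim`
# (a length-1 dim remains), then squeeze only that dimension.
decomposed = torch.narrow(x, dim, index, 1).squeeze(dim)

assert torch.equal(selected, decomposed)
print(selected.shape)  # torch.Size([2, 4])
```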

backends/arm/arm_partitioner.py

Lines changed: 11 additions & 4 deletions
@@ -7,14 +7,14 @@

 import logging
 import os
-from typing import Callable, final, List, Optional, Tuple
+from typing import Callable, final, List, Optional, Sequence, Tuple

 import torch
 from executorch.backends.arm.arm_backend import (  # type: ignore[attr-defined]
     ArmBackend,
 )  # usort: skip
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
-    TOSASupportedOperators,
+    tosa_support_factory,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.exir.backend.compile_spec_schema import CompileSpec

@@ -27,6 +27,8 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.operator_support import OperatorSupportBase
+

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)

@@ -54,8 +56,13 @@ def is_dequant_node(node: torch.fx.node.Node) -> bool:

 @final
 class ArmPartitioner(Partitioner):
-    def __init__(self, compile_spec: List[CompileSpec]) -> None:
+    def __init__(
+        self,
+        compile_spec: List[CompileSpec],
+        additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
+    ) -> None:
         self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)
+        self.additional_checks = additional_checks

     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible

@@ -72,7 +79,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            TOSASupportedOperators(tosa_spec),
+            tosa_support_factory(tosa_spec, self.additional_checks),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
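A hedged sketch of how the new `additional_checks` hook might be used: callers can pass extra `OperatorSupportBase` checks that further restrict what gets partitioned. The `RejectLargeIntermediates` class and its size threshold are made up for illustration; only the constructor signature comes from this diff.

```python
from torch.fx.passes.operator_support import OperatorSupportBase


class RejectLargeIntermediates(OperatorSupportBase):
    """Hypothetical extra check: refuse nodes whose output tensor is too big."""

    def is_node_supported(self, submodules, node) -> bool:
        val = node.meta.get("val", None)
        return not hasattr(val, "numel") or val.numel() <= 1_000_000


# Assuming `compile_spec` was built as usual for the Arm backend:
# partitioner = ArmPartitioner(compile_spec, additional_checks=[RejectLargeIntermediates()])
# edge_program = edge_program.to_backend(partitioner)
```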
Lines changed: 9 additions & 2 deletions
@@ -1,8 +1,15 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe

-from . import right_shift_support, to_copy_support, tosa_supported_operators  # noqa
+from . import (  # noqa
+    convolution_support,
+    pool_2d_support,
+    reduce_sum_support,
+    right_shift_support,
+    to_copy_support,
+    tosa_supported_operators,
+)
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+import torch
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_tosa_support_check
+class ConvolutionSupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.convolution.default]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+
+        # Not implemented
+        transposed = cast(bool, node.args[6])
+        output_padding = cast(list[int], node.args[7])
+        if transposed:
+            return False
+
+        for pad in output_padding:
+            if pad != 0:
+                return False
+
+        # Hardware specific constraints
+        if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
+            return True
+        else:
+            return self._is_node_supported_u55(node)
+
+    def _is_node_supported_u55(self, node: fx.Node):
+        """Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""
+
+        shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
+        shape_out = node.meta["val"].shape
+        kernel = cast(fx.Node, node.args[1]).meta["val"].shape
+        group = cast(int, node.args[8])
+
+        C_in = shape_in[1]
+        C_out = shape_out[1]
+        if (C_in == group) and (C_out % C_in) == 0:
+            # Depthwise convolution
+            for dim in shape_in[1:]:
+                if not 1 <= dim <= 65536:
+                    return False
+        else:
+            # Convolution
+            if not 1 <= C_in <= 65536:
+                return False
+
+        kernel_w = kernel[2]
+        kernel_h = kernel[3] if len(kernel) > 3 else 1
+        # Kernel condition misses constraint on sum of absolute weights
+        if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096:
+            return False
+
+        if not self._stride_condition(node):
+            return False
+
+        return True
+
+    def _stride_condition(self, node: fx.Node) -> bool:
+        """This condition is somewhat complex but boils down
+        to not supporting stride > 3, unless we have some special conditions.
+        This condition is a simplified, relaxed version of the hardware constraint,
+        since the actual constraint requires information not available
+        here (without a lot of work).

+        This means that we might accept ops that are not actually supported.
+        """
+        strides = cast(list[int], node.args[3])
+        has_padding = any(pad > 0 for pad in cast(list[int], node.args[4]))
+        dilations = cast(list[int], node.args[5])
+        if len(dilations) == 1:
+            dilations = [dilations[0]] * 2
+        if len(strides) == 1:
+            strides = [strides[0]] * 2
+
+        for stride, dilation in zip(strides, dilations):
+            stride_condition = 1 <= stride <= 3
+            dilation_condition = (not has_padding) and (dilation == 1)
+            if (not stride_condition) and (not dilation_condition):
+                return False
+
+        return True
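For a concrete reading of the depthwise branch above, a small hedged sketch in plain torch: the check treats a convolution as depthwise when the input channel count equals the group count and the output channels are a multiple of it. The channel and kernel sizes below are chosen only for illustration.

```python
import torch

# Example convolution that would hit the depthwise branch of the U55 check:
# C_in == groups (32 == 32) and C_out % C_in == 0 (64 % 32 == 0).
conv = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, groups=32)
x = torch.randn(1, 32, 16, 16)

out = conv(x)
print(out.shape)  # torch.Size([1, 64, 14, 14])

# Kernel-size constraint from the check: dims within limits and their
# product at most 4096 (trivially satisfied by a 3x3 kernel).
k1, k2 = conv.weight.shape[2], conv.weight.shape[3]
assert 1 <= k1 <= 64 and 1 <= k1 * k2 <= 4096
```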
