Skip to content

Commit ffd1cf4

Browse files
committed
Update base for Update on "[ExecuTorch][Llama] Split custom sdpa op and kv cache"
Summary: This enables us to do more easier module swap with model definitions from torchtune Test Plan: CI Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D67914056](https://our.internmc.facebook.com/intern/diff/D67914056) [ghstack-poisoned]
2 parents 4093a5d + d1b33cb commit ffd1cf4

File tree

165 files changed

+3734
-2006
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

165 files changed

+3734
-2006
lines changed

.ci/docker/requirements-ci.txt

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
mpmath==1.3.0
2-
numpy==1.21.3; python_version == '3.10'
3-
numpy==1.23.2; python_version == '3.11'
4-
numpy; python_version >= '3.12'
2+
numpy==2.0.0; python_version >= '3.10'
53
PyYAML==6.0.1
64
ruamel.yaml==0.17.32
75
sympy==1.12
86
timm==0.6.13
97
tomli==2.0.1
108
torchsr==1.0.4
11-
transformers==4.38.0
9+
transformers==4.47.1
1210
zstd==1.5.5.1
13-
pandas==2.0.3; python_version == '3.10'
14-
pandas; python_version >= '3.11'
11+
pandas==2.2.2; python_version >= '3.10'
1512
pytest==7.2.0
1613
pytest-cov==4.1.0
1714
expecttest==0.1.6
@@ -24,7 +21,7 @@ sphinx-gallery==0.14.0
2421
breathe==4.34.0
2522
exhale==0.2.3
2623
docutils==0.16
27-
matplotlib==3.7.2
24+
matplotlib==3.9.4
2825
# PyTorch Theme
2926
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
3027
myst-parser==0.18.1

.ci/scripts/build-qnn-sdk.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/bin/bash
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
# All rights reserved.
45
#
56
# This source code is licensed under the BSD-style license found in the
@@ -11,10 +12,16 @@ set -o xtrace
1112
build_qnn_backend() {
1213
echo "Start building qnn backend."
1314
export ANDROID_NDK_ROOT=/opt/ndk
14-
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
15+
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
1516
export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)"
1617

17-
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release
18+
# Workaround to avoid issues around missing flatccrt library (depending on the
19+
# number of jobs used), see issue #7300:
20+
# Build twice (second time with `--no_clean`) to make sure libflatccrt.a is
21+
# available.
22+
# TODO: Remove this workaround once the underlying issue is fixed.
23+
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release || \
24+
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release --no_clean
1825
}
1926

2027
set_up_aot() {

.ci/scripts/setup-qnn-deps.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ install_qnn() {
1616
QNN_INSTALLATION_DIR=/tmp/qnn
1717
mkdir -p "${QNN_INSTALLATION_DIR}"
1818

19-
curl -Lo /tmp/v2.25.0.24.07.28.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.25.0.240728.zip"
19+
curl -Lo /tmp/v2.28.0.24.10.29.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.28.0.241029.zip"
2020
echo "Finishing downloading qnn sdk."
21-
unzip -qo /tmp/v2.25.0.24.07.28.zip -d /tmp
21+
unzip -qo /tmp/v2.28.0.24.10.29.zip -d /tmp
2222
echo "Finishing unzip qnn sdk."
2323

2424

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ echo "COREML option ${COREML}"
121121
if [[ "${MODE}" =~ .*qnn.* ]]; then
122122
QNN=ON
123123
export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)"
124-
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
124+
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
125125
export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang"
126126
export PYTHONPATH=".."
127127
cp schema/program.fbs exir/_serialize/program.fbs

.github/pytorch-probot.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# The schema is from https://github.com/pytorch/pytorch/blob/main/.github/pytorch-probot.yml
2+
tracking_issue: 7679
23
ciflow_push_tags:
34
- ciflow/android
45
- ciflow/apple

.github/workflows/android-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ jobs:
260260
--output_name="${OUT_ET_MODEL_NAME}.pte"
261261
ls -lh "${OUT_ET_MODEL_NAME}.pte"
262262
elif [[ ${{ matrix.config }} == "llama3_qnn_htp" ]]; then
263-
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
263+
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
264264
export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/
265265
export PYTHONPATH=$(pwd)/..
266266
@@ -347,7 +347,7 @@ jobs:
347347
PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
348348
349349
export ANDROID_ABIS="arm64-v8a"
350-
PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
350+
PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
351351
352352
# Let's see how expensive this job is, we might want to tone it down by running it periodically
353353
benchmark-on-device:

backends/apple/coreml/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ class Model(torch.nn.Module):
9393
source_model = Model()
9494
example_inputs = (torch.randn((1, 3, 256, 256)), )
9595

96-
pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
96+
pre_autograd_aten_dialect = export_for_training(source_model, example_inputs).module()
9797

9898
quantization_config = LinearQuantizerConfig.from_dict(
9999
{
100100
"global_config": {
101101
"quantization_scheme": QuantizationScheme.symmetric,
102-
"activation_dtype": torch.uint8,
103-
"weight_dtype": torch.int8,
102+
"activation_dtype": torch.quint8,
103+
"weight_dtype": torch.qint8,
104104
"weight_per_channel": True,
105105
}
106106
}

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,7 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel
4747

4848
echo "${green}ExecuTorch: Installing coremltools."
4949
pip install "$COREMLTOOLS_DIR_PATH"
50-
# CoreMLTools have started supporting numpy 2.0,
51-
# but ExecuTorch example model test env is still using older transformers,
52-
# so for now we will need to downgrade numpy to 1.x
53-
# TODO: Remove this numpy downgrade once later transformers starts to be used
54-
pip install numpy==1.26.4
50+
5551
STATUS=$?
5652
if [ $STATUS -ne 0 ]; then
5753
echo "${red}ExecuTorch: Failed to install coremltools."

backends/arm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ backends/arm/test/setup_testing.sh
119119
The you can run the tests with
120120

121121
```
122-
pytest -c /dev/null -v -n auto backends/arm/test --arm_quantize_io --arm_run_corstoneFVP
122+
pytest -c /dev/null -v -n auto backends/arm/test --arm_run_corstoneFVP
123123
```
124124

125125
### Code coverage

backends/arm/_passes/arm_pass_manager.py

Lines changed: 66 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
# pyre-unsafe
99

10-
import torch
1110
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
1211
AnnotateChannelsLastDimOrder,
1312
)
@@ -28,6 +27,7 @@
2827
)
2928
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
3029
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
30+
from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
3131
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
3232
DecomposeSoftmaxesPass,
3333
)
@@ -46,7 +46,7 @@
4646
)
4747
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
4848
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
49-
ConvertMeanDimToAveragePool,
49+
ConvertMeanDimToAveragePoolPass,
5050
)
5151
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
5252
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
@@ -60,92 +60,98 @@
6060
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
6161
UnsqueezeScalarPlaceholdersPass,
6262
)
63+
from executorch.backends.arm.tosa_specification import TosaSpecification
6364
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
6465
from executorch.exir import ExportedProgram
65-
from executorch.exir.backend.compile_spec_schema import CompileSpec
66-
from executorch.exir.dialects._ops import ops as exir_ops
6766
from executorch.exir.pass_manager import PassManager
67+
from torch.fx import GraphModule
6868

6969

7070
class ArmPassManager(PassManager):
7171

72-
def _transform(self, graph_module: torch.fx.GraphModule):
72+
def __init__(self, tosa_spec: TosaSpecification) -> None:
73+
self.tosa_spec = tosa_spec
74+
super().__init__()
75+
76+
def _transform(self, graph_module: GraphModule):
7377
return self(graph_module).graph_module
7478

75-
def transform_to_backend_pipeline(
76-
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
77-
):
78-
"""Apply passes before transforming program to backend"""
79+
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
7980
self.add_pass(FuseQuantizedActivationPass())
81+
self.add_pass(RemoveGetItemPass())
82+
self.add_pass(ConvertSplitToSlicePass())
83+
self.add_pass(ConvertMmToBmmPass())
8084
self.add_pass(DecomposeLinearPass())
85+
self.add_pass(ConvertMeanDimToAveragePoolPass())
86+
87+
self.add_pass(AnnotateDecomposedMatmulPass())
88+
self.add_pass(QuantizeFullArgument())
89+
self.add_pass(FoldAndAnnotateQParamsPass())
90+
self.add_pass(RetraceFoldedDtypesPass())
91+
self.add_pass(InsertTableOpsPass(exported_program))
92+
93+
self.add_pass(RemoveClonePass())
94+
self.add_pass(SizeAdjustConv2DPass())
95+
self.add_pass(ConvertExpandCopyToRepeatPass())
96+
self.add_pass(UnsqueezeBeforeRepeatPass())
97+
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
98+
self.add_pass(CastInt64ToInt32Pass(exported_program))
99+
self.add_pass(MatchArgRanksPass(exported_program))
100+
self.add_pass(KeepDimsFalseToSqueezePass())
101+
self.add_pass(Conv1dUnsqueezePass(exported_program))
102+
self.add_pass(DecomposeSelectPass())
103+
104+
self.add_pass(AnnotateChannelsLastDimOrder())
105+
106+
return self._transform(exported_program.graph_module)
107+
108+
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
109+
110+
self.add_pass(FuseQuantizedActivationPass())
81111
self.add_pass(RemoveGetItemPass())
112+
self.add_pass(ConvertSplitToSlicePass())
113+
self.add_pass(ConvertMmToBmmPass())
114+
self.add_pass(DecomposeLinearPass())
82115
self.add_pass(DecomposeLayerNormPass())
83116
self.add_pass(DecomposeVarPass())
84-
self.add_pass(ConvertMeanDimToAveragePool())
85117
self.add_pass(DecomposeMeanDimPass())
86-
self.add_pass(ConvertSplitToSlicePass())
87-
self.add_pass(ConvertMmToBmmPass())
88-
# TODO MLETORCH-558
118+
self.add_pass(ConvertMeanDimToAveragePoolPass())
119+
self.add_pass(DecomposeDivPass())
120+
self.add_pass(DecomposeSoftmaxesPass())
121+
89122
self.add_pass(AnnotateDecomposedMatmulPass())
90123
self.add_pass(QuantizeFullArgument())
91-
self.add_pass(
92-
FoldAndAnnotateQParamsPass(
93-
[
94-
exir_ops.edge.aten.minimum.default,
95-
exir_ops.edge.aten.maximum.default,
96-
exir_ops.edge.aten.add.Tensor,
97-
exir_ops.edge.aten.avg_pool2d.default,
98-
exir_ops.edge.aten.bmm.default,
99-
exir_ops.edge.aten.cat.default,
100-
exir_ops.edge.aten.convolution.default,
101-
exir_ops.edge.aten.clone.default,
102-
exir_ops.edge.aten.exp.default,
103-
exir_ops.edge.aten.expand_copy.default,
104-
exir_ops.edge.aten.full.default,
105-
exir_ops.edge.aten.hardtanh.default,
106-
exir_ops.edge.aten.log.default,
107-
exir_ops.edge.aten.max_pool2d.default,
108-
exir_ops.edge.aten.mul.Tensor,
109-
exir_ops.edge.aten.permute_copy.default,
110-
exir_ops.edge.aten.reciprocal.default,
111-
exir_ops.edge.aten.relu.default,
112-
exir_ops.edge.aten.repeat.default,
113-
exir_ops.edge.aten.rsqrt.default,
114-
exir_ops.edge.aten.select_copy.int,
115-
exir_ops.edge.aten.sigmoid.default,
116-
exir_ops.edge.aten.slice_copy.Tensor,
117-
exir_ops.edge.aten.squeeze_copy.dims,
118-
exir_ops.edge.aten.sub.Tensor,
119-
exir_ops.edge.aten.sum.dim_IntList,
120-
exir_ops.edge.aten.tanh.default,
121-
exir_ops.edge.aten.unsqueeze_copy.default,
122-
exir_ops.edge.aten.upsample_nearest2d.vec,
123-
exir_ops.edge.aten.view_copy.default,
124-
]
125-
)
126-
)
124+
self.add_pass(FoldAndAnnotateQParamsPass())
127125
self.add_pass(RetraceFoldedDtypesPass())
128126
self.add_pass(InsertTableOpsPass(exported_program))
127+
128+
self.add_pass(RemoveClonePass())
129+
self.add_pass(SizeAdjustConv2DPass())
129130
self.add_pass(ConvertExpandCopyToRepeatPass())
130131
self.add_pass(UnsqueezeBeforeRepeatPass())
131-
self.add_pass(CastInt64ToInt32Pass(exported_program))
132132
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
133-
self.add_pass(SizeAdjustConv2DPass())
134-
self.add_pass(RemoveClonePass())
133+
self.add_pass(CastInt64ToInt32Pass(exported_program))
135134
self.add_pass(MatchArgRanksPass(exported_program))
136-
self.add_pass(DecomposeDivPass())
137135
self.add_pass(KeepDimsFalseToSqueezePass())
138136
self.add_pass(Conv1dUnsqueezePass(exported_program))
139-
self.add_pass(DecomposeSoftmaxesPass())
140-
for spec in compile_spec:
141-
if spec.key == "permute_memory_format":
142-
memory_format = spec.value.decode()
143-
if memory_format == "nhwc":
144-
self.add_pass(AnnotateChannelsLastDimOrder())
137+
self.add_pass(DecomposeSelectPass())
138+
139+
self.add_pass(AnnotateChannelsLastDimOrder())
145140

146141
return self._transform(exported_program.graph_module)
147142

148-
def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
143+
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
144+
"""Apply passes before transforming program to backend"""
145+
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
146+
return self._tosa_080_BI_pipeline(exported_program)
147+
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
148+
return self._tosa_080_MI_pipeline(exported_program)
149+
else:
150+
raise NotImplementedError(
151+
f"No pass pipeline implemented for {self.tosa_spec=}"
152+
)
153+
154+
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
149155
self.add_pass(ScalarsToAttributePass())
150156
self.add_pass(DecomposeLayerNormPass())
151157
self.add_pass(DecomposeVarPass())

backends/arm/_passes/cast_int64_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -17,6 +17,10 @@
1717

1818

1919
class CastInt64ToInt32Pass(ExportPass):
20+
"""
21+
Cast int64 buffers to int32 if the int64 data is in int32 range.
22+
"""
23+
2024
def __init__(self, exported_program: torch.export.ExportedProgram):
2125
super(CastInt64ToInt32Pass, self).__init__()
2226
self.exported_program = exported_program
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
14+
15+
class DecomposeSelectPass(ExportPass):
16+
"""
17+
This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs has the same rank (input rank -1)
18+
"""
19+
20+
def call(self, graph_module: torch.fx.GraphModule):
21+
for node in graph_module.graph.nodes:
22+
23+
if node.op != "call_function":
24+
continue
25+
26+
if node.target in (
27+
exir_ops.edge.aten.select.int,
28+
exir_ops.edge.aten.select_copy.int,
29+
):
30+
slice_op = exir_ops.edge.aten.slice_copy.Tensor
31+
squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
32+
else:
33+
continue
34+
35+
input_node, dim, index = node.args
36+
37+
rank = len(input_node.meta["val"].size())
38+
dim = dim % rank if dim < 0 else dim
39+
index = index % rank if index < 0 else index
40+
dim_list = list(range(rank))
41+
42+
with graph_module.graph.inserting_before(node):
43+
slice_node = create_node(
44+
graph_module.graph, slice_op, (input_node, dim, index, index + 1)
45+
)
46+
squeeze_node = create_node(
47+
graph_module.graph, squeeze_op, (slice_node, dim_list)
48+
)
49+
50+
node.replace_all_uses_with(squeeze_node)
51+
graph_module.graph.erase_node(node)
52+
53+
graph_module.graph.eliminate_dead_code()
54+
graph_module.recompile()
55+
graph_module = super().call(graph_module).graph_module
56+
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)