Commit be78434

Update
[ghstack-poisoned]

2 parents: ab96fd7 + a01571f

95 files changed: +2951, -469 lines


.ci/scripts/unittest-linux.sh

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ PYTHON_EXECUTABLE=python ./examples/models/llama3_2_vision/install_requirements.
 if [[ "$BUILD_TOOL" == "cmake" ]]; then
   .ci/scripts/unittest-linux-cmake.sh
 elif [[ "$BUILD_TOOL" == "buck2" ]]; then
-  .ci/scripts/unittest-linux-buck2.sh
+  .ci/scripts/unittest-buck2.sh
 else
   echo "Unknown build tool $BUILD_TOOL"
   exit 1

CMakeLists.txt

Lines changed: 8 additions & 0 deletions

@@ -186,6 +186,10 @@ option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR "Build the Flat Tensor extension"
        OFF
 )
 
+option(EXECUTORCH_BUILD_EXTENSION_LLM "Build the LLM extension"
+       OFF
+)
+
 option(EXECUTORCH_BUILD_EXTENSION_MODULE "Build the Module extension" OFF)
 
 option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL "Build the Runner Util extension"
@@ -718,6 +722,10 @@ if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor/serialize)
 endif()
 
+if(EXECUTORCH_BUILD_EXTENSION_LLM)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/tokenizer)
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_MODULE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
 endif()
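Note: the new EXECUTORCH_BUILD_EXTENSION_LLM option defaults to OFF; enabling it at configure time (for example, passing -DEXECUTORCH_BUILD_EXTENSION_LLM=ON to cmake) adds extension/llm/tokenizer to the build.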

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 15 additions & 1 deletion

@@ -3,7 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
 import logging
-from typing import List, Optional
+from typing import Callable, List, Optional, Tuple
 
 import coremltools as ct
 
@@ -104,3 +104,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        do_not_decompose = []
+        op_support = OperatorsSupportedForCoreMLBackend()
+        for node in ep.graph.nodes:
+            if (
+                node.op == "call_function"
+                and isinstance(node.target, torch._ops.OpOverload)
+                and op_support.is_node_supported(None, node)
+            ):
+                do_not_decompose.append(node.target)
+        return do_not_decompose, None
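The new ops_to_not_decompose hook is what executorch.exir.to_edge_transform_and_lower consults to keep backend-supported ops intact rather than decomposing them before delegation. A minimal sketch of calling the hook directly (the SDPA module, names, and shapes here are illustrative, not from this commit):

    import torch
    from executorch.backends.apple.coreml.partition import CoreMLPartitioner

    class SDPA(torch.nn.Module):
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    inputs = tuple(torch.randn(1, 4, 8, 16) for _ in range(3))
    ep = torch.export.export(SDPA(), inputs)
    # Returns the ATen overloads to preserve plus an optional node filter;
    # the filter is None here, meaning "preserve every listed op".
    ops, filter_fn = CoreMLPartitioner().ops_to_not_decompose(ep)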

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 46 additions & 0 deletions

@@ -13,6 +13,7 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
 
 
 class TestCoreMLPartitioner(unittest.TestCase):
@@ -79,6 +80,50 @@ def test_vit_skip_conv(self):
             "getitem",
         ]
 
+    def test_ops_to_not_decompose(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v, mask):
+                return torch.ops.aten.scaled_dot_product_attention.default(
+                    q, k, v, attn_mask=mask
+                )
+
+        model = Model()
+        model.eval()
+
+        batch_size = 1
+        n_heads = 12
+        seq_len = 1
+        max_seq_length = 32
+        embedding_dim = 16
+        q = torch.randn(batch_size, n_heads, seq_len, embedding_dim)
+        k = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        v = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        mask = torch.randn(seq_len, max_seq_length)
+        example_inputs = (q, k, v, mask)
+        ep = torch.export.export(model, example_inputs)
+        coreml_partitioner = CoreMLPartitioner()
+
+        # Using to_edge_transform_and_lower, we expect SDPA will be preserved and show up in delegated graph
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[coreml_partitioner]
+        )
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            in format_delegated_graph(
+                edge_program_manager.exported_program().graph_module
+            )
+        )
+
+        # Using to_edge flow, we expect SDPA will be decomposed and not show up in delegated graph
+        edge_program_manager2 = executorch.exir.to_edge(ep)
+        edge_program_manager2.to_backend(coreml_partitioner)
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            not in format_delegated_graph(
+                edge_program_manager2.exported_program().graph_module
+            )
+        )
+
     def test_buffer(self):
         embedding_dim = 3
         max_seq_len = 2
@@ -129,4 +174,5 @@ def forward(self, q, k_val, input_pos):
 test_runner = TestCoreMLPartitioner()
 test_runner.test_add_sub_skip_mm()
 test_runner.test_vit_skip_conv()
+test_runner.test_ops_to_not_decompose()
 test_runner.test_buffer()

backends/arm/TARGETS

Lines changed: 4 additions & 1 deletion

@@ -4,7 +4,10 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
 python_library(
     name = "arm_partitioner",
     srcs = [
-        "arm_partitioner.py",
+        "ethosu_backend.py",
+        "ethosu_partitioner.py",
+        "tosa_backend.py",
+        "tosa_partitioner.py",
     ],
     typing = True,
     deps = [

backends/arm/_passes/arm_pass_manager.py

Lines changed: 10 additions & 2 deletions

@@ -52,6 +52,7 @@
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
+from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass
 from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
@@ -75,6 +76,10 @@
     UnsqueezeScalarPlaceholdersPass,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
+
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -100,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(ConvertFullLikeToFullPass())
 
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
@@ -119,11 +125,11 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(AnnotateChannelsLastDimOrder())
-
+        self.add_pass(InsertRescalePass())
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
-
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -157,6 +163,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(InsertRescalePass())
 
         return self._transform(exported_program.graph_module)
 
@@ -173,6 +180,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
 
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
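ReplaceScalarWithTensorArgPass now runs in all three pipelines. Judging by its name and its shared location under backends/transforms, it swaps scalar-argument op variants for their tensor-argument counterparts so later stages only see tensor-tensor overloads; a conceptual sketch of that equivalence (not the pass implementation itself):

    import torch

    x = torch.tensor([1.0, 2.0])
    # aten.add.Scalar(x, 2.0) agrees elementwise with aten.add.Tensor applied
    # to a 0-dim tensor holding the same scalar.
    assert torch.equal(
        torch.ops.aten.add.Scalar(x, 2.0),
        torch.ops.aten.add.Tensor(x, torch.full((), 2.0)),
    )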

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 3 additions & 0 deletions

@@ -131,6 +131,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.op != "call_function":
                 continue
+            # Don't fold chains of quant-ops into each other.
+            if n.target in (q_op, dq_op):
+                continue
 
             # Make sure we haven't already set qparams meta information on the node
             assert "input_qparams" not in n.meta.keys()

backends/arm/_passes/fuse_batchnorm2d_pass.py

Lines changed: 1 addition & 0 deletions

@@ -114,6 +114,7 @@ def try_set_param(
             if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
                 bn_bias_node, fused_conv_bias
             ):
+                # pyre-ignore[60]
                 # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
                 conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
                 conv.args = conv_args
backends/arm/_passes/insert_rescales_pass.py (new file; path per the new import in arm_pass_manager.py)

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from copy import copy
+from typing import cast
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch import Tensor
+from torch.fx import GraphModule, Node
+from torch.library import custom_op, register_fake
+
+logger = logging.getLogger(__name__)
+
+
+@custom_op("tosa::_rescale", mutates_args=())  # type: ignore[misc]
+def rescale(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    logger.warning(
+        "Ran default implementation of tosa::_rescale. "
+        "This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
+    )
+    # Clone is needed to not return a reference when rescaling to the same dtype.
+    # This is a necessary requirement for non-mutating custom ops.
+    return x.to(dtype=dtype).clone()
+
+
+@register_fake("tosa::_rescale")  # type: ignore[misc]
+def rescale_fake(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
+    Additionally validates TOSA constraints of a RESCALE op.
+    """
+    if not (dtype == torch.int32 or dtype == torch.int8):
+        raise NotImplementedError(
+            "tosa::rescale currently only supports int32 and int8."
+        )
+    if dtype == torch.int32 and out_zp != 0:
+        raise ValueError(
+            "TOSA requires output_zp to be zero when the output dtype is int32."
+        )
+    if x.dtype == torch.int32 and in_zp != 0:
+        raise ValueError(
+            "TOSA requires input_zp to be zero when the input dtype is int32."
+        )
+    if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
+        raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
+    if dtype == torch.int8 and not -128 <= out_zp <= 127:
+        raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")
+
+    return x.to(dtype=dtype).clone()
+
+
+class InsertRescalePass(ExportPass):
+    """Finds patterns of dq -> q, and replaces them
+    with passthrough_to_tosa::rescales.
+
+    Does not guarantee that the dtypes and zero points are valid
+    in TOSA; that is the job of the quantization annotator that
+    produced the dq and q nodes. The TOSA constraints are validated
+    in the fake implementation of passthrough_to_tosa::rescale.
+    """
+
+    def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
+        dq_args = QuantArgs.from_operator(node.target, node.args)
+        q_args = QuantArgs.from_operator(user.target, user.args)
+        new_scale = dq_args.scale / q_args.scale
+
+        with graph_module.graph.inserting_before(node):
+            rescale_node = create_node(
+                graph_module.graph,
+                torch.ops.tosa._rescale.default,
+                (
+                    node.all_input_nodes[0],
+                    q_args.dtype,
+                    new_scale,
+                    dq_args.zp,
+                    q_args.zp,
+                ),
+            )
+            rescale_node.meta = copy(user.meta)
+            user.replace_all_uses_with(rescale_node)
+            graph_module.graph.erase_node(user)
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+        for node in graph_module.graph.nodes:
+            node = cast(Node, node)
+
+            if node.target is not dq_op:
+                continue
+            # Copy users since we remove them while iterating, modifying the node.users list.
+            for user in copy(node.users):
+                if user.target is q_op:
+                    self.fold_dq_q_to_rescale(node, user, graph_module)
+                    modified = True
+            if len(node.users) == 0:
+                graph_module.graph.erase_node(node)
+
+        graph_module = super().call(graph_module).graph_module
+        graph_module.recompile()
+        return PassResult(graph_module, modified)
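A quick numeric check of why folding dq -> q into a single rescale is sound: dequantize followed by requantize composes into one affine step whose scale is scale_dq / scale_q, which is exactly the new_scale the pass hands to tosa::_rescale. A minimal sketch with hypothetical quantization parameters:

    import torch

    x_q = torch.tensor([10, 20, 30], dtype=torch.int8)
    scale_dq, zp_dq = 0.5, 2   # dequantize params (hypothetical)
    scale_q, zp_q = 0.25, 0    # requantize params (hypothetical)

    # Two steps: dequantize to float, then requantize.
    x_f = (x_q.to(torch.int32) - zp_dq) * scale_dq
    two_step = torch.round(x_f / scale_q + zp_q).to(torch.int8)

    # One step: rescale by scale_dq / scale_q between the two zero points.
    new_scale = scale_dq / scale_q
    one_step = torch.round((x_q.to(torch.int32) - zp_dq) * new_scale + zp_q).to(torch.int8)

    assert torch.equal(two_step, one_step)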

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 12 additions & 0 deletions

@@ -76,5 +76,17 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 new_args.append(get_attr_node)
             n.args = tuple(new_args)
 
+            # Replace rsub.Scalar with sub.Tensor as retracing will fail otherwise
+            if n.target == torch.ops.aten.rsub.Scalar:
+                with graph_module.graph.inserting_after(n):
+                    reversed_args = (n.args[1], n.args[0])
+                    sub = graph_module.graph.create_node(
+                        "call_function", torch.ops.aten.sub.Tensor, reversed_args, {}
+                    )
+                    n.replace_all_uses_with(sub)
+                    sub.meta["val"] = n.meta["val"]
+                    graph_module.graph.erase_node(n)
+
         graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
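The swap is sound because aten.rsub(a, b) computes b - a, so reversing the arguments into aten.sub preserves semantics; a one-line check:

    import torch

    a, b = torch.tensor([1.0, 2.0]), torch.tensor([5.0, 7.0])
    assert torch.equal(torch.rsub(a, b), torch.sub(b, a))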

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@
 # pyre-unsafe
 
 from . import (  # noqa
+    bitwise_support,
     convolution_support,
     pool_2d_support,
     reduce_sum_support,
backends/arm/operator_support/bitwise_support.py (new file; path per the new import in operator_support/__init__.py)

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_tosa_support_check
+class BitwiseSupported(SupportedTOSAOperatorCheck):
+    targets = [
+        exir_ops.edge.aten.bitwise_and.Tensor,
+        exir_ops.edge.aten.bitwise_or.Tensor,
+        exir_ops.edge.aten.bitwise_xor.Tensor,
+    ]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+        # U55 case, Vela 4.2.0 (25.02 release)
+        if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
+            return False
+
+        return True
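For illustration, a small module whose exported graph exercises the newly supported targets (BitOps is a hypothetical name, not from this commit); after edge conversion its &, |, and ^ become the bitwise_and/or/xor.Tensor overloads listed in targets above, making them candidates for TOSA delegation everywhere except the U55 subset:

    import torch

    class BitOps(torch.nn.Module):
        def forward(self, a, b):
            return a & b, a | b, a ^ b

    a = torch.randint(0, 8, (4,), dtype=torch.int32)
    b = torch.randint(0, 8, (4,), dtype=torch.int32)
    ep = torch.export.export(BitOps(), (a, b))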
