Skip to content

Commit 7057315

Browse files
authored
Merge branch 'main' into jz/update-android-readme
2 parents 85bb6c3 + d0464f8 commit 7057315

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

57 files changed

+949
-568
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
1010
from .arm_pass import ArmPass # noqa
11+
from .broadcast_args_pass import BroadcastArgsPass # noqa
1112
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1213
from .cast_to_int32_pass import CastToInt32Pass # noqa
1314
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from executorch.backends.arm._passes import (
1111
AnnotateChannelsLastDimOrder,
1212
AnnotateDecomposedMatmulPass,
13+
BroadcastArgsPass,
1314
CastInt64BuffersToInt32Pass,
1415
CastToInt32Pass,
1516
ComputeConstantOpsAOT,
@@ -104,6 +105,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
104105
self.add_pass(RetraceFoldedDtypesPass())
105106
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
106107
self.add_pass(MatchArgRanksPass(exported_program))
108+
if self.tosa_spec.is_U55_subset:
109+
self.add_pass(BroadcastArgsPass())
107110
self.add_pass(ComputeConstantOpsAOT(exported_program))
108111

109112
self.add_pass(RemoveClonePass())
backends/arm/_passes/broadcast_args_pass.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm._passes import ArmPass
7+
8+
from executorch.backends.arm._passes.arm_pass_utils import (
9+
create_node,
10+
get_first_fake_tensor,
11+
)
12+
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
15+
from executorch.exir.pass_base import PassResult
16+
from torch.fx import GraphModule, Node
17+
18+
19+
class BroadcastArgsPass(ArmPass):
20+
"""
21+
Pass to manually broadcast arguments by inserting repeats.
22+
This is done when more than one arg needs broadcasting.
23+
"""
24+
25+
targeted_ops = {
26+
exir_ops.edge.aten.add.Tensor,
27+
exir_ops.edge.aten.sub.Tensor,
28+
# mul is indirectly targeting div as div is decomposed to reciprocal + mul
29+
exir_ops.edge.aten.mul.Tensor,
30+
}
31+
32+
def call(self, graph_module: GraphModule) -> PassResult:
33+
for node in graph_module.graph.nodes:
34+
if node.op != "call_function" or node.target not in self.targeted_ops:
35+
continue
36+
37+
output_shape = get_first_fake_tensor(node).shape
38+
nbr_of_broacasts = 0
39+
for arg in node.args:
40+
if not isinstance(arg, Node):
41+
continue
42+
43+
shape = get_first_fake_tensor(arg).shape
44+
if shape != output_shape:
45+
nbr_of_broacasts += 1
46+
if nbr_of_broacasts > 1:
47+
multiples = [
48+
int(output_shape[d] / shape[d])
49+
for d in range(len(output_shape))
50+
]
51+
with graph_module.graph.inserting_before(node):
52+
repeat = create_node(
53+
graph_module.graph,
54+
exir_ops.edge.aten.repeat.default,
55+
args=(arg, multiples),
56+
kwargs={},
57+
from_node=node,
58+
)
59+
node.replace_input_with(arg, repeat)
60+
61+
graph_module.recompile()
62+
graph_module = super().call(graph_module).graph_module
63+
return PassResult(graph_module, True)

backends/arm/arm_vela.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False)
7373
np_path = os.path.join(tmpdir, "output", "out_vela.npz")
7474
else:
7575
np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
76-
blocks = b""
7776

77+
blocks = b""
7878
with np.load(np_path, allow_pickle=False) as data:
7979
# Construct our modified output_blocks with data in a form easily
8080
# digested on the device side
@@ -92,7 +92,7 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False)
9292
if not isinstance(data["scratch_shape"][0], np.int64):
9393
raise RuntimeError("Expected scratch to be int64")
9494
block_length = int(data["scratch_shape"][0])
95-
bin_blocks["scratch_data"] = b"\x00" * block_length
95+
bin_blocks["scratch_size"] = struct.pack("<I", block_length)
9696

9797
# Capture inputs and outputs
9898
bin_blocks["inputs"] = vela_bin_pack_io("input", data)

backends/arm/operator_support/to_copy_support.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
import copy
78
import logging
89

910
import torch
@@ -42,7 +43,9 @@ def _merge_supported_types(
4243
dtypes1: SupportedTypeDict,
4344
dtypes2: SupportedTypeDict,
4445
) -> SupportedTypeDict:
45-
merged_dtypes = dtypes1
46+
merged_dtypes = copy.deepcopy(
47+
dtypes1
48+
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
4649
for k, v in dtypes2.items():
4750
merged_dtypes[k] = merged_dtypes.get(k, []) + v
4851
return merged_dtypes

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ class CheckProperQuantization(OperatorSupportBase):
306306
exir_ops.edge.aten.sub.Tensor,
307307
exir_ops.edge.aten.upsample_bilinear2d.vec,
308308
exir_ops.edge.aten.upsample_nearest2d.vec,
309+
torch.ops.aten.scalar_tensor.default,
309310
*TableOps.included_ops(),
310311
)
311312

backends/arm/operators/op_abs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def define_node(
164164
scale_back = 1.0
165165
if inputs[0].dtype == ts.DType.INT8:
166166
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
167-
tosa_graph, inputs, node, self.tosa_specs
167+
tosa_graph, inputs, node, self.tosa_spec
168168
) # type: ignore[possibly-undefined]
169169
else:
170170
# input[0].dtype == ts.DType.INT32
@@ -192,7 +192,7 @@ def define_node(
192192
# Scale output back to 8 bit
193193
# pyre-ignore
194194
tqutils.insert_rescale_op_to_int8(
195-
tosa_graph, abs_output, scale_back, node, self.tosa_specs
195+
tosa_graph, abs_output, scale_back, node, self.tosa_spec
196196
) # type: ignore[possibly-undefined]
197197

198198

backends/arm/operators/op_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def define_node(
174174
scale_back = 1.0
175175
if inputs[0].dtype == ts.DType.INT8:
176176
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
177-
tosa_graph, inputs, node, self.tosa_specs
177+
tosa_graph, inputs, node, self.tosa_spec
178178
)
179179
else:
180180
# input[0].dtype == ts.DType.INT32
@@ -202,7 +202,7 @@ def define_node(
202202
# Scale output back to 8 bit
203203
# pyre-ignore
204204
tqutils.insert_rescale_op_to_int8(
205-
tosa_graph, add_output, scale_back, node, self.tosa_specs
205+
tosa_graph, add_output, scale_back, node, self.tosa_spec
206206
) # type: ignore[possibly-undefined]
207207

208208

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def define_node(
9898
if inputs[0].dtype == ts.DType.INT8:
9999
# Rescale inputs to 32 bit
100100
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
101-
tosa_graph, inputs, node, self.tosa_specs
101+
tosa_graph, inputs, node, self.tosa_spec
102102
)
103103

104104
# Update IO

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_gt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_le.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_lt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_maximum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def define_node(
129129
)
130130

131131
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
132-
tosa_graph, inputs, node, self.tosa_specs
132+
tosa_graph, inputs, node, self.tosa_spec
133133
)
134134

135135
output.shape = tosa_shape(output.shape, output.dim_order)
@@ -155,5 +155,5 @@ def define_node(
155155
if output.dtype == ts.DType.INT8:
156156
# insert RESCALE from int32 back to int8
157157
tqutils.insert_rescale_op_to_int8(
158-
tosa_graph, max_output, scale_back, node, self.tosa_specs
158+
tosa_graph, max_output, scale_back, node, self.tosa_spec
159159
)

backends/arm/operators/op_minimum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def define_node(
128128
)
129129

130130
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
131-
tosa_graph, inputs, node, self.tosa_specs
131+
tosa_graph, inputs, node, self.tosa_spec
132132
)
133133

134134
output.shape = tosa_shape(output.shape, output.dim_order)
@@ -154,5 +154,5 @@ def define_node(
154154
if output.dtype == ts.DType.INT8:
155155
# insert RESCALE from int32 back to int8
156156
tqutils.insert_rescale_op_to_int8(
157-
tosa_graph, min_output, scale_back, node, self.tosa_specs
157+
tosa_graph, min_output, scale_back, node, self.tosa_spec
158158
)

backends/arm/operators/op_mul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,14 @@ def define_node(
189189
input_A,
190190
input_A_qargs.zp,
191191
[1.0],
192-
tosa_spec=self.tosa_specs,
192+
tosa_spec=self.tosa_spec,
193193
)
194194
input_B_rescaled = tqutils.build_rescale_to_int32(
195195
tosa_graph,
196196
input_B,
197197
input_B_qargs.zp,
198198
[1.0],
199-
tosa_spec=self.tosa_specs,
199+
tosa_spec=self.tosa_spec,
200200
)
201201

202202
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
@@ -211,7 +211,7 @@ def define_node(
211211
)
212212
output_scale = input_A_qargs.scale * input_B_qargs.scale
213213
tqutils.insert_rescale_op_to_int8(
214-
tosa_graph, mul_output, output_scale, node, self.tosa_specs
214+
tosa_graph, mul_output, output_scale, node, self.tosa_spec
215215
)
216216

217217

backends/arm/operators/op_neg.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
NodeVisitor,
1717
register_node_visitor,
1818
)
19-
19+
from executorch.backends.arm.operators.operator_validation_utils import (
20+
validate_num_inputs,
21+
validate_same_dtype,
22+
)
2023
from executorch.backends.arm.tosa_mapping import TosaArg
2124

2225

@@ -60,14 +63,12 @@ def define_node(
6063
ts.DType.FP32,
6164
}
6265

66+
validate_num_inputs(self.target, inputs, 1)
67+
validate_same_dtype(self.target, [*inputs, output])
68+
6369
if inputs[0].dtype not in supported_dtypes:
6470
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
6571

66-
if inputs[0].dtype != output.dtype:
67-
raise ValueError(
68-
"All inputs and output need same dtype."
69-
f"Got {inputs[0].dtype=}, {output.dtype=}"
70-
)
7172
input_zp, output_zp = get_negate_zero_points(
7273
node, inputs[0].dtype == ts.DType.INT8
7374
)
@@ -109,14 +110,12 @@ def define_node(
109110
ts.DType.FP32,
110111
}
111112

113+
validate_num_inputs(self.target, inputs, 1)
114+
validate_same_dtype(self.target, [*inputs, output])
115+
112116
if inputs[0].dtype not in supported_dtypes:
113117
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
114118

115-
if inputs[0].dtype != output.dtype:
116-
raise ValueError(
117-
"All inputs and output need same dtype."
118-
f"Got {inputs[0].dtype=}, {output.dtype=}"
119-
)
120119
input_zp, output_zp = get_negate_zero_points(
121120
node, inputs[0].dtype == ts.DType.INT8
122121
)

backends/arm/operators/op_sub.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,16 @@ def define_node(
163163
validate_same_dtype(self.target, [*inputs, output])
164164

165165
# Handle int8 (quantized) and int32
166-
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
166+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
167+
if inputs[0].dtype not in supported_dtypes:
168+
raise TypeError(
169+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
170+
)
167171

168172
scale_back = 1.0
169173
if inputs[0].dtype == ts.DType.INT8:
170174
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
171-
tosa_graph, inputs, node, self.tosa_specs
175+
tosa_graph, inputs, node, self.tosa_spec
172176
)
173177
else:
174178
# input[0].dtype == ts.DType.INT32
@@ -197,7 +201,7 @@ def define_node(
197201
# Scale output back to 8 bit
198202
# pyre-ignore
199203
tqutils.insert_rescale_op_to_int8(
200-
tosa_graph, sub_output, scale_back, node, self.tosa_specs
204+
tosa_graph, sub_output, scale_back, node, self.tosa_spec
201205
) # type: ignore[possibly-undefined]
202206

203207

@@ -228,8 +232,15 @@ def define_node(
228232
super().define_node(node, tosa_graph, inputs, output)
229233
else:
230234
# FP32 Sub lowering
231-
assert inputs[0].dtype == ts.DType.FP32
232-
assert output.dtype == ts.DType.FP32
235+
if (
236+
inputs[0].dtype != ts.DType.FP32
237+
or inputs[1].dtype != ts.DType.FP32
238+
or output.dtype != ts.DType.FP32
239+
):
240+
raise TypeError(
241+
f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, "
242+
f"input 2: {inputs[1].dtype} and output: {output.dtype}"
243+
)
233244

234245
# MI lowering
235246
tosa_graph.addOperator(

0 commit comments

Comments
 (0)