Skip to content

Commit d2089a0

Browse files
committed
Update on "[executorch][core] NamedDataMap interface"
Add NamedDataMap interface to runtime. Differential Revision: [D66834552](https://our.internmc.facebook.com/intern/diff/D66834552/) [ghstack-poisoned]
2 parents 93c0fb6 + 712f4cb commit d2089a0

File tree

101 files changed

+4030
-1063
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+4030
-1063
lines changed

.ci/docker/requirements-ci.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mpmath==1.3.0
2-
numpy==2.0.0; python_version >= '3.10'
2+
numpy>=2.0.0; python_version >= '3.10'
33
PyYAML==6.0.1
44
ruamel.yaml==0.17.32
55
sympy==1.12
@@ -8,7 +8,7 @@ tomli==2.0.1
88
torchsr==1.0.4
99
transformers==4.47.1
1010
zstd==1.5.5.1
11-
pandas==2.2.2; python_version >= '3.10'
11+
pandas>=2.2.2; python_version >= '3.10'
1212
pytest==7.2.0
1313
pytest-cov==4.1.0
1414
expecttest==0.1.6
@@ -21,7 +21,7 @@ sphinx-gallery==0.14.0
2121
breathe==4.34.0
2222
exhale==0.2.3
2323
docutils==0.16
24-
matplotlib==3.9.4
24+
matplotlib>=3.9.4
2525
# PyTorch Theme
2626
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
2727
myst-parser==0.18.1

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from executorch.backends.arm._passes.arm_pass_utils import (
1313
create_node,
1414
get_first_fake_tensor,
15-
get_node_arg,
1615
insert_q_dq_pair,
1716
)
1817
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
@@ -26,9 +25,8 @@
2625
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
2726
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
2827
lib = Library("passthrough_to_tosa", "DEF")
29-
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
30-
# to switch dim_order before the opertation. Changing tosa_dim_order is not sufficient
31-
# as we also need transpose the data into the correct data format.
28+
# For certain operators we need the data in a specific data format. Changing tosa_dim_order
29+
# is not sufficient as we also need to transpose the data.
3230
# By utilizing an edge IR passthrough operator we can keep the edge program in
3331
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
3432
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
@@ -153,27 +151,6 @@ def insert_output_transpose(node, graph_module):
153151
q_params = node.args[0].args[1:]
154152
insert_q_dq_pair(graph_module.graph, node, q_params)
155153

156-
@staticmethod
157-
def _insert_squeeze_transpose(
158-
input_shape, output_shape, node, input_node, graph_module
159-
):
160-
nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
161-
162-
if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
163-
input_shape
164-
):
165-
AnnotateChannelsLastDimOrder.insert_input_transpose(
166-
node, input_node, graph_module
167-
)
168-
169-
@staticmethod
170-
def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
171-
nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
172-
if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
173-
output_shape
174-
):
175-
AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
176-
177154
@staticmethod
178155
def _insert_view_transpose(
179156
input_shape, output_shape, node, input_node, graph_module
@@ -199,8 +176,6 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
199176
"""
200177
Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all others are in NCHW format.
201178
This is relevant for the following cases:
202-
- squeeze: 4D -> <4D
203-
- unsqueeze: 3D -> 4D
204179
- view: <4D -> 4D
205180
- view: 4D -> <4D
206181
Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leading to one extra input and output transpose for this case.
@@ -214,27 +189,6 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
214189
if node.op != "call_function":
215190
continue
216191

217-
if node.target == exir_ops.edge.aten.squeeze_copy.dims:
218-
input_node = node.args[0]
219-
input_shape = input_node.meta["val"].shape
220-
output_shape = node.meta["val"].shape
221-
222-
self._insert_squeeze_transpose(
223-
input_shape, output_shape, node, input_node, graph_module
224-
)
225-
226-
elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
227-
input_node = get_node_arg(node.args, 0, default_value=False)
228-
if input_node:
229-
input_shape = input_node.meta["val"].shape
230-
else:
231-
input_shape = ()
232-
output_shape = node.meta["val"].shape
233-
234-
self._insert_unsqueeze_transpose(
235-
input_shape, output_shape, node, graph_module
236-
)
237-
238192
elif node.target == exir_ops.edge.aten.view_copy.default:
239193
input_node = node.args[0]
240194
input_shape = input_node.meta["val"].shape

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from executorch.backends.arm._passes.convert_split_to_slice import (
2222
ConvertSplitToSlicePass,
2323
)
24+
from executorch.backends.arm._passes.convert_squeezes_to_view import (
25+
ConvertSqueezesToViewPass,
26+
)
2427
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
2528
from executorch.backends.arm._passes.decompose_layernorm_pass import (
2629
DecomposeLayerNormPass,
@@ -100,6 +103,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
100103
self.add_pass(KeepDimsFalseToSqueezePass())
101104
self.add_pass(Conv1dUnsqueezePass(exported_program))
102105
self.add_pass(DecomposeSelectPass())
106+
self.add_pass(ConvertSqueezesToViewPass())
103107

104108
self.add_pass(AnnotateChannelsLastDimOrder())
105109

@@ -135,6 +139,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
135139
self.add_pass(KeepDimsFalseToSqueezePass())
136140
self.add_pass(Conv1dUnsqueezePass(exported_program))
137141
self.add_pass(DecomposeSelectPass())
142+
self.add_pass(ConvertSqueezesToViewPass())
138143

139144
self.add_pass(AnnotateChannelsLastDimOrder())
140145

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class ConvertSqueezesToViewPass(ExportPass):
    """
    Rewrites ``squeeze_copy.dims`` and ``unsqueeze_copy.default`` as ``view_copy.default``.

    Squeeze and unsqueeze are special cases of the view op, so replacing them
    leaves fewer cases to handle in the node visitors.
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Forward every operator unchanged, except squeeze/unsqueeze.

        For those two ops a ``view_copy.default`` call is emitted instead. It
        takes the first original argument as input and, as target shape, the
        size recorded in ``meta["val"]``. ``kwargs`` and ``meta`` are passed
        through as they are.
        """
        is_squeeze = op == exir_ops.edge.aten.squeeze_copy.dims
        is_unsqueeze = op == exir_ops.edge.aten.unsqueeze_copy.default

        if is_squeeze or is_unsqueeze:
            # The output shape is already known from the node's meta value, so
            # the original dim argument(s) are not needed for the view.
            replacement_args = (args[0], list(meta["val"].size()))
            return super().call_operator(
                exir_ops.edge.aten.view_copy.default,
                replacement_args,
                kwargs,
                meta,
            )

        return super().call_operator(op, args, kwargs, meta)

backends/arm/operator_support/to_copy_support.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -22,7 +22,10 @@
2222

2323
@register_tosa_support_check
2424
class ToCopySupported(SupportedTOSAOperatorCheck):
25-
targets = [exir_ops.edge.aten._to_copy.default]
25+
targets = [
26+
exir_ops.edge.aten._to_copy.default,
27+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
28+
]
2629

2730
tosa_specs = [
2831
TosaSpecification.create_from_string("TOSA-0.80+BI"),
@@ -110,7 +113,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
110113
)
111114
return False
112115

113-
# Check memory format
116+
# Check memory format (to_copy)
114117
if "memory_format" in node.kwargs:
115118
if node.kwargs["memory_format"] in (torch.preserve_format,):
116119
logger.info(
@@ -119,4 +122,14 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
119122
)
120123
return False
121124

125+
# Check dim_order (to_dim_order_copy)
126+
if "dim_order" in node.kwargs:
127+
dim_order = node.kwargs["dim_order"]
128+
if dim_order != list(range(len(dim_order))):
129+
logger.info(
130+
f"Argument {dim_order=} is not supported for "
131+
f"{node.target.name()} right now." # pyre-ignore[16]
132+
)
133+
return False
134+
122135
return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -81,11 +81,16 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
8181
exir_ops.edge.aten.hardtanh.default,
8282
exir_ops.edge.aten.convolution.default,
8383
exir_ops.edge.aten.div.Tensor,
84+
exir_ops.edge.aten.eq.Tensor,
8485
exir_ops.edge.aten.exp.default,
8586
exir_ops.edge.aten.log.default,
8687
exir_ops.edge.aten.linear.default,
8788
exir_ops.edge.aten.split_with_sizes_copy.default,
8889
exir_ops.edge.aten.full.default,
90+
exir_ops.edge.aten.ge.Tensor,
91+
exir_ops.edge.aten.gt.Tensor,
92+
exir_ops.edge.aten.le.Tensor,
93+
exir_ops.edge.aten.lt.Tensor,
8994
exir_ops.edge.aten.mul.Tensor,
9095
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
9196
exir_ops.edge.aten.native_layer_norm.default,

backends/arm/operators/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313
op_bmm,
1414
op_cat,
1515
op_conv2d,
16+
op_eq,
1617
op_exp,
1718
op_full,
19+
op_ge,
1820
op_get_item,
21+
op_gt,
1922
op_hardtanh,
23+
op_le,
2024
op_log,
25+
op_lt,
2126
op_max,
2227
op_max_pool2d,
2328
op_min,
@@ -30,14 +35,13 @@
3035
op_rsqrt,
3136
op_sigmoid,
3237
op_slice,
33-
op_squeeze,
3438
op_sub,
3539
op_sum,
3640
op_table,
3741
op_tanh,
3842
op_to_copy,
43+
op_to_dim_order_copy,
3944
op_transpose,
40-
op_unsqueeze,
4145
op_upsample_nearest2d,
4246
op_view,
4347
)

backends/arm/operators/op_eq.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from serializer.tosa_serializer import TosaOp
19+
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
class EqualVisitor(NodeVisitor):
    """Node visitor that emits a TOSA ``EQUAL`` operator for ``aten.eq.Tensor``."""

    target = "aten.eq.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """
        Add an ``EQUAL`` operator comparing ``inputs[0]`` with ``inputs[1]``.

        When the inputs are INT8 they are first passed through
        ``insert_rescale_ops_to_int32`` and the comparison is done on the
        rescaled tensors; otherwise the inputs are compared directly.

        Args:
            node: The FX node being lowered; forwarded to the rescale helper.
            tosa_graph: Serializer that receives the new operator.
            inputs: The two operands of the comparison.
            output: The argument whose name is used for the operator output.

        Raises:
            AssertionError: If the two inputs do not share the same dtype.
        """
        assert (
            inputs[0].dtype == inputs[1].dtype
        ), "EQ must have the same dtypes as input"

        input_nodes = inputs
        # Handle quantization
        if inputs[0].dtype == ts.DType.INT8:
            # Rescale inputs to 32 bit
            rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node
            )

            # Update IO
            input_nodes = rescaled_inputs

        # Do the equal comparison. The output name is passed as a list, the
        # same way the inputs are and the same way GreaterEqualVisitor does it.
        tosa_graph.addOperator(
            TosaOp.Op().EQUAL,
            [input_nodes[0].name, input_nodes[1].name],
            [output.name],
            None,
        )

backends/arm/operators/op_ge.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from serializer.tosa_serializer import TosaOp
19+
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
class GreaterEqualVisitor(NodeVisitor):
    """Node visitor that emits a TOSA ``GREATER_EQUAL`` operator for ``aten.ge.Tensor``."""

    target = "aten.ge.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """
        Add a ``GREATER_EQUAL`` operator for ``inputs[0] >= inputs[1]``.

        INT8 operands are replaced by the tensors returned from
        ``insert_rescale_ops_to_int32`` before the comparison is added.
        Both operands must have the same dtype.
        """
        lhs, rhs = inputs[0], inputs[1]
        assert lhs.dtype == rhs.dtype, "GE must have the same dtypes as input"

        if lhs.dtype == ts.DType.INT8:
            # Quantized path: compare the 32 bit rescaled tensors instead.
            operands, _ = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node
            )
        else:
            operands = inputs

        ge_op = TosaOp.Op().GREATER_EQUAL
        operand_names = [operands[0].name, operands[1].name]
        tosa_graph.addOperator(ge_op, operand_names, [output.name], None)

0 commit comments

Comments
 (0)