
Commit c1eac04

Add pass to annotate dim-order
The dim-order of each node is annotated in a pass.
Some refactoring of arm_backend.py.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I14691b51b99acb9e8605100fd25731ab45c55a9d
1 parent 089858b commit c1eac04

19 files changed: +355 -285 lines changed
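
In summary, the diffs below move memory-format handling out of the individual operator visitors: instead of each visitor receiving a permute_memory_to_nhwc flag and reordering shapes inline, a PermuteMemoryPass annotates the dim-order of every node up front, and the serialization code reads that annotation back through a tosa_shape(shape, dim_order) helper.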

backends/arm/arm_backend.py

Lines changed: 11 additions & 68 deletions
@@ -16,17 +16,17 @@
 import serializer.tosa_serializer as ts
 from executorch.backends.arm.arm_vela import vela_compile
 from executorch.backends.arm.operators.node_visitor import get_node_visitors
+from executorch.backends.arm.operators.op_output import process_output
 from executorch.backends.arm.operators.op_placeholder import process_placeholder
-from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import is_quant_node
+from executorch.backends.arm.passes.permute_memory_pass import PermuteMemoryPass
 from executorch.backends.arm.tosa_utils import (
     dbg_fail,
     dbg_tosa_dump,
-    is_consumer_node_depthwise_conv2d,
-    is_permute_node_before_addmm,
+    process_call_function,
 )
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.pass_manager import PassManager
 from torch.export.exported_program import ExportedProgram

 # TOSA backend debug functionality
@@ -243,77 +243,20 @@ def preprocess( # noqa: C901
         # Converted output for this subgraph, serializer needs path early as it emits
         # const data directly. Path created and data written only in debug builds.
         tosa_graph = ts.TosaSerializer(path)
+        passes = PassManager()
+        if permute_memory_to_nhwc:
+            passes.add_pass(PermuteMemoryPass(edge_program))
+        passes(edge_program.graph_module)

         node_visitors = get_node_visitors(edge_program)

         for node in edge_program.graph.nodes:
             if node.op == "call_function":
-                # Unpack arguments and convert
-                inputs = []
-                for arg in node.args:
-                    inputs.append(TosaArg(arg))
-
-                # Convert output (this node itself)
-                output = TosaArg(node)
-
-                # TODO: fragile code for temporary fix, not all outputs will be
-                # rank 4
-                if permute_memory_to_nhwc and len(output.shape) == 4:
-                    # TODO: remove this if check
-                    # this is added because we need to align the quant node
-                    # output shape before the depthwise_conv2d node. The output
-                    # shape between TOSA conv2d and depthwise_conv2d are different.
-                    if (
-                        node.all_input_nodes[0].op
-                        == "placeholder"  # check its parent is a placeholder
-                        and is_quant_node(node)
-                        and is_consumer_node_depthwise_conv2d(node)
-                    ):
-                        NHWC_Order = [2, 3, 0, 1]
-                    else:
-                        NHWC_Order = [0, 2, 3, 1]
-                    output.shape = [output.shape[i] for i in NHWC_Order]
-
-                # Add output to TOSA graph
-                tosa_graph.currRegion.currBasicBlock.addTensor(
-                    output.name,
-                    (
-                        inputs[0].shape
-                        if is_permute_node_before_addmm(node)
-                        else output.shape
-                    ),
-                    ts.DType.INT8 if is_quant_node(node) else output.dtype,
-                )
-
-                # Visiting each Node
-                if node.target.__name__ in node_visitors:
-                    if node.target.__name__ in [
-                        "aten.add.Tensor",
-                        "aten._native_batch_norm_legit_no_training.default",
-                    ]:
-                        node_visitors[node.target.__name__].define_node(
-                            node,
-                            tosa_graph,
-                            inputs,
-                            output,
-                            is_quant_node(node),
-                            permute_memory_to_nhwc,
-                        )
-                    else:
-                        node_visitors[node.target.__name__].define_node(
-                            node, tosa_graph, inputs, output, is_quant_node(node)
-                        )
-                else:
-                    raise RuntimeError(f"Unknown operator {node.target}")
+                process_call_function(node, tosa_graph, node_visitors)
             elif node.op == "placeholder":
-                process_placeholder(
-                    node, tosa_graph, edge_program, permute_memory_to_nhwc
-                )
+                process_placeholder(node, tosa_graph, edge_program)
             elif node.op == "output":
-                for output in node.args[0]:
-                    tosa_graph.addOutputTensor(
-                        tosa_graph.currRegion.currBasicBlock.tensors[output.name]
-                    )
+                process_output(node, tosa_graph)
             else:
                 # This will only happen if an unpartitioned graph is passed without
                 # any checking of compatibility.
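
The PermuteMemoryPass wired in above is imported from backends/arm/passes/permute_memory_pass.py, a file not shown in this excerpt. A minimal sketch of what such an annotation pass could look like, assuming it follows the torch.fx PassBase pattern and uses a hypothetical node.meta key:

# Hypothetical sketch; the actual pass is not part of this excerpt.
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class PermuteMemoryPass(PassBase):
    def __init__(self, exported_program):
        super().__init__()
        self.exported_program = exported_program

    def call(self, graph_module):
        for node in graph_module.graph.nodes:
            # Record the desired memory order on the node so that the
            # serializer can permute shapes lazily via tosa_shape() instead
            # of rewriting them in place. The real pass presumably also
            # covers the depthwise-conv special case removed from
            # arm_backend.py above.
            node.meta["dim_order"] = (0, 2, 3, 1)  # assumed key and NHWC order
        return PassResult(graph_module, True)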

backends/arm/operators/op_add.py

Lines changed: 3 additions & 7 deletions
@@ -16,7 +16,7 @@
     build_rescale_from_int32,
     build_rescale_to_int32,
 )
-from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs
+from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs, tosa_shape
 from serializer.tosa_serializer import TosaOp


@@ -34,7 +34,6 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
         is_quant_node: bool,
-        permute_memory_to_nhwc: bool,
     ) -> None:
         if is_quant_node:
             # Single input or not
@@ -54,12 +53,9 @@ def define_node(
             inputA_rescale_scale = input_A_scale.number / min_scale
             inputB_rescale_scale = input_B_scale.number / min_scale

+            input_A.shape = tosa_shape(input_A.shape, input_A.dim_order)
+            input_B.shape = tosa_shape(input_B.shape, input_B.dim_order)
             broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
-            if permute_memory_to_nhwc:
-                NHWC_Order = [0, 2, 3, 1]
-                broadcasted_shape = [broadcasted_shape[i] for i in NHWC_Order]
-                input_A.shape = [input_A.shape[i] for i in NHWC_Order]
-                input_B.shape = [input_B.shape[i] for i in NHWC_Order]

             input_A_rescaled_to_int32 = build_rescale_to_int32(
                 tosa_graph,
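
Every operator now goes through tosa_shape(shape, dim_order) rather than permuting shapes inline. The helper comes from backends/arm/tosa_utils.py (imported above) and is not shown on this page; assuming it simply applies the annotated permutation, a minimal sketch:

# Sketch under the assumption that dim_order is a permutation of range(rank).
def tosa_shape(shape, dim_order):
    # e.g. the NCHW shape (1, 8, 32, 32) with dim_order (0, 2, 3, 1)
    # becomes the NHWC shape [1, 32, 32, 8].
    return [shape[i] for i in dim_order]

Read that way, the two added lines above bring input_A.shape and input_B.shape into their annotated memory order before broadcast_shapes is computed, which is what the removed permute_memory_to_nhwc branch used to do by hand.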

backends/arm/operators/op_batch_norm.py

Lines changed: 22 additions & 18 deletions
@@ -11,7 +11,7 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_utils import promote_shape
+from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
 from serializer.tosa_serializer import TosaOp


@@ -25,12 +25,9 @@ def __init__(self, *args):
    # For BatchNorm2D, mean and var are calculated over the channel dimension
    # But TOSA doesn't allow subtraction of inputs with different ranks
    # Need to augment the shapes to match the ranks with activations
-    def augment_shape_rank(self, input, permute_memory_to_nhwc):
-        return (
-            (1, 1, 1) + input.shape
-            if permute_memory_to_nhwc
-            else ((1,) + input.shape + (1, 1))
-        )
+    def augment_shape_rank(self, shape, dim_order):
+        nchw_shape = (1, *shape, 1, 1)
+        return tosa_shape(nchw_shape, dim_order)

    def define_node(
        self,
@@ -39,7 +36,6 @@ def define_node(
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
-        permute_memory_to_nhwc: bool,
    ) -> None:
        # Decompose batch norm into sequence
        (activations, weights, bias, running_mean, running_var, momentum, epsilon) = (
@@ -67,13 +63,15 @@ def define_node(
            mean_reshaped = promote_shape(
                tosa_graph,
                running_mean,
-                self.augment_shape_rank(running_mean, permute_memory_to_nhwc),
+                self.augment_shape_rank(running_mean.shape, output.dim_order),
                input_dtype,
            )

            # Subtract mean
            # %op1 = tosa.SUB(%activations, %running_mean)
-            op1 = tosa_graph.addIntermediate(output.shape, input_dtype)
+            op1 = tosa_graph.addIntermediate(
+                tosa_shape(output.shape, output.dim_order), input_dtype
+            )
            tosa_graph.addOperator(
                TosaOp.Op().SUB,
                [activations.name, mean_reshaped.name],
@@ -82,7 +80,9 @@
            # Adding eplison to variance
            # %op2 = tosa.ADD(%running_var, %epsilon_const)
            epsilon_const = tosa_graph.addConst([1], input_dtype, [epsilon.number])
-            op2 = tosa_graph.addIntermediate(running_var.shape, input_dtype)
+            op2 = tosa_graph.addIntermediate(
+                tosa_shape(running_var.shape, running_var.dim_order), input_dtype
+            )
            tosa_graph.addOperator(
                TosaOp.Op().ADD,
                [running_var.name, epsilon_const.name],
@@ -97,7 +97,7 @@ def define_node(
            op3_reshaped = promote_shape(
                tosa_graph,
                op3,
-                self.augment_shape_rank(running_var, permute_memory_to_nhwc),
+                self.augment_shape_rank(running_var.shape, output.dim_order),
                input_dtype,
            )

@@ -114,7 +114,9 @@ def define_node(
        else:
            # Multiply shifted activations with reciprocal variance
            # %op4 = tosa.MUL(%op1, %op3)
-            op4 = tosa_graph.addIntermediate(output.shape, input_dtype)
+            op4 = tosa_graph.addIntermediate(
+                tosa_shape(output.shape, output.dim_order), input_dtype
+            )
            attr_mul = ts.TosaSerializerAttribute()
            attr_mul.MulAttribute(0)
            tosa_graph.addOperator(
@@ -130,7 +132,7 @@ def define_node(
            weights_reshaped = promote_shape(
                tosa_graph,
                weights,
-                self.augment_shape_rank(weights, permute_memory_to_nhwc),
+                self.augment_shape_rank(weights.shape, output.dim_order),
                input_dtype,
            )

@@ -152,7 +154,7 @@ def define_node(
            bias_reshaped = promote_shape(
                tosa_graph,
                bias,
-                self.augment_shape_rank(bias, permute_memory_to_nhwc),
+                self.augment_shape_rank(bias.shape, output.dim_order),
                input_dtype,
            )

@@ -170,12 +172,14 @@ def define_node(
            weights_reshaped = promote_shape(
                tosa_graph,
                weights,
-                self.augment_shape_rank(weights, permute_memory_to_nhwc),
+                self.augment_shape_rank(weights.shape, output.dim_order),
                input_dtype,
            )

            # %op5 = tosa.MUL(%op4, %weights)
-            op5 = tosa_graph.addIntermediate(output.shape, input_dtype)
+            op5 = tosa_graph.addIntermediate(
+                tosa_shape(output.shape, output.dim_order), input_dtype
+            )
            attr_mul = ts.TosaSerializerAttribute()
            attr_mul.MulAttribute(0)
            tosa_graph.addOperator(
@@ -189,7 +193,7 @@ def define_node(
            bias_reshaped = promote_shape(
                tosa_graph,
                bias,
-                self.augment_shape_rank(bias, permute_memory_to_nhwc),
+                self.augment_shape_rank(bias.shape, output.dim_order),
                input_dtype,
            )
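
The reworked augment_shape_rank first builds a rank-4 NCHW-shaped tuple and then lets tosa_shape apply the node's dim-order. A short worked example, using the tosa_shape reading sketched above, shows that this reproduces both branches of the removed implementation:

# Per-channel parameter of a BatchNorm2d with 16 channels:
shape = (16,)
nchw_shape = (1, *shape, 1, 1)          # (1, 16, 1, 1)

tosa_shape(nchw_shape, (0, 1, 2, 3))    # [1, 16, 1, 1]  (contiguous, old NCHW branch)
tosa_shape(nchw_shape, (0, 2, 3, 1))    # [1, 1, 1, 16]  (NHWC, old permute branch)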

backends/arm/operators/op_conv2d.py

Lines changed: 4 additions & 2 deletions
@@ -15,7 +15,7 @@
     build_rescale_conv_output,
     get_quant_node_args,
 )
-from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs
+from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape

 from serializer.tosa_serializer import TosaOp

@@ -107,7 +107,9 @@ def define_node(
         # The output type is int32 when input type is int8.
         conv2d_output_name = output.name
         if is_quant_node:
-            conv2d_res = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
+            conv2d_res = tosa_graph.addIntermediate(
+                tosa_shape(output.shape, output.dim_order), ts.DType.INT32
+            )
             conv2d_output_name = conv2d_res.name

         # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)

backends/arm/operators/op_div.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023 Arm Limited and/or its affiliates.
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,6 +11,7 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_utils import tosa_shape
 from serializer.tosa_serializer import TosaOp


@@ -30,7 +31,9 @@ def define_node(
         is_quant_node: bool,
     ) -> None:
         # FP32 Div is implemented as output=x/y -> output=x*1/y e.g. MUL(x,RECIPROCAL(y))
-        recip = tosa_graph.addIntermediate(inputs[1].shape, inputs[1].dtype)
+        recip = tosa_graph.addIntermediate(
+            tosa_shape(inputs[1].shape, inputs[1].dim_order), inputs[1].dtype
+        )
         tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[1].name], [recip.name])

         attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_output.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import serializer.tosa_serializer as ts
+import torch
+
+
+def process_output(
+    node: torch.fx.Node,
+    tosa_graph: ts.TosaSerializer,
+):
+    for output in node.args[0]:
+        tosa_graph.addOutputTensor(
+            tosa_graph.currRegion.currBasicBlock.tensors[output.name]
+        )
