Skip to content

Commit d3c92de

Browse files
Remove memory-format workaround for Arm backend (#3981)
Summary: Remove temporary fix for memory format introduced in #2371. The dim-order of each node is annotated in a pass. Also some refactoring of arm_backend.py. Pull Request resolved: #3981 Reviewed By: kirklandsign Differential Revision: D59280091 Pulled By: digantdesai fbshipit-source-id: f591161830b7c0f836f0be9c33c12c1282f4cc4d
1 parent ac43606 commit d3c92de

19 files changed

+354
-329
lines changed

backends/arm/arm_backend.py

Lines changed: 6 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,13 @@
1616
import serializer.tosa_serializer as ts
1717
from executorch.backends.arm.arm_vela import vela_compile
1818
from executorch.backends.arm.operators.node_visitor import get_node_visitors
19+
from executorch.backends.arm.operators.op_output import process_output
1920
from executorch.backends.arm.operators.op_placeholder import process_placeholder
2021
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
21-
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
22-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
2322
from executorch.backends.arm.tosa_utils import (
2423
dbg_fail,
2524
dbg_tosa_dump,
26-
is_consumer_node_depthwise_conv2d,
27-
is_permute_node_before_addmm,
25+
process_call_function,
2826
)
2927
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
3028
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -45,6 +43,7 @@ def __init__(self):
4543
self.compiler_flags = []
4644
self.output_format = None
4745
self.path_for_intermediates = None
46+
# TODO MLETORCH-265 Remove permute_nhwc flag
4847
self.permute_nhwc = False
4948
self.quantize_io = False
5049

@@ -217,18 +216,13 @@ def preprocess( # noqa: C901
217216
artifact_path = None
218217
output_format = ""
219218
compile_flags = []
220-
permute_memory_to_nhwc = False
221219
for spec in compile_spec:
222220
if spec.key == "debug_artifact_path":
223221
artifact_path = spec.value.decode()
224222
if spec.key == "output_format":
225223
output_format = spec.value.decode()
226224
if spec.key == "compile_flags":
227225
compile_flags.append(spec.value.decode())
228-
if spec.key == "permute_memory_format":
229-
memory_format = spec.value.decode()
230-
if memory_format == "nhwc":
231-
permute_memory_to_nhwc = True
232226

233227
# Check that the output format is set in the compile spec
234228
if not output_format:
@@ -250,76 +244,11 @@ def preprocess( # noqa: C901
250244

251245
for node in graph_module.graph.nodes:
252246
if node.op == "call_function":
253-
# Unpack arguments and convert
254-
inputs = []
255-
for arg in node.args:
256-
inputs.append(TosaArg(arg))
257-
258-
# Convert output (this node itself)
259-
output = TosaArg(node)
260-
261-
# TODO: fragile code for temporary fix, not all outputs will be
262-
# rank 4
263-
if permute_memory_to_nhwc and len(output.shape) == 4:
264-
# TODO: remove this if check
265-
# this is added because we need to align the quant node
266-
# output shape before the depthwise_conv2d node. The output
267-
# shape between TOSA conv2d and depthwise_conv2d are different.
268-
if (
269-
node.all_input_nodes[0].op
270-
== "placeholder" # check its parent is a placeholder
271-
and is_quant_node(node)
272-
and is_consumer_node_depthwise_conv2d(node)
273-
):
274-
NHWC_Order = [2, 3, 0, 1]
275-
else:
276-
NHWC_Order = [0, 2, 3, 1]
277-
output.shape = [output.shape[i] for i in NHWC_Order]
278-
279-
# Add output to TOSA graph
280-
tosa_graph.currRegion.currBasicBlock.addTensor(
281-
output.name,
282-
(
283-
inputs[0].shape
284-
if is_permute_node_before_addmm(node)
285-
else output.shape
286-
),
287-
(
288-
map_dtype(get_quant_node_dtype(node))
289-
if is_quant_node(node)
290-
else output.dtype
291-
),
292-
)
293-
294-
# Visiting each Node
295-
if node.target.__name__ in node_visitors:
296-
if node.target.__name__ in [
297-
"aten.add.Tensor",
298-
"aten._native_batch_norm_legit_no_training.default",
299-
]:
300-
node_visitors[node.target.__name__].define_node(
301-
node,
302-
tosa_graph,
303-
inputs,
304-
output,
305-
is_quant_node(node),
306-
permute_memory_to_nhwc,
307-
)
308-
else:
309-
node_visitors[node.target.__name__].define_node(
310-
node, tosa_graph, inputs, output, is_quant_node(node)
311-
)
312-
else:
313-
raise RuntimeError(f"Unknown operator {node.target}")
247+
process_call_function(node, tosa_graph, node_visitors)
314248
elif node.op == "placeholder":
315-
process_placeholder(
316-
node, tosa_graph, edge_program, permute_memory_to_nhwc
317-
)
249+
process_placeholder(node, tosa_graph, edge_program)
318250
elif node.op == "output":
319-
for output in node.args[0]:
320-
tosa_graph.addOutputTensor(
321-
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
322-
)
251+
process_output(node, tosa_graph)
323252
else:
324253
# This will only happen if an unpartitioned graph is passed without
325254
# any checking of compatibility.

backends/arm/operators/op_add.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
build_rescale_from_int32,
1717
build_rescale_to_int32,
1818
)
19-
from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs
19+
from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs, tosa_shape
2020
from serializer.tosa_serializer import TosaOp
2121

2222

@@ -34,7 +34,6 @@ def define_node(
3434
inputs: List[TosaArg],
3535
output: TosaArg,
3636
is_quant_node: bool,
37-
permute_memory_to_nhwc: bool,
3837
) -> None:
3938
if is_quant_node:
4039
# Single input or not
@@ -54,12 +53,9 @@ def define_node(
5453
inputA_rescale_scale = input_A_scale.number / min_scale
5554
inputB_rescale_scale = input_B_scale.number / min_scale
5655

56+
input_A.shape = tosa_shape(input_A.shape, input_A.dim_order)
57+
input_B.shape = tosa_shape(input_B.shape, input_B.dim_order)
5758
broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
58-
if permute_memory_to_nhwc:
59-
NHWC_Order = [0, 2, 3, 1]
60-
broadcasted_shape = [broadcasted_shape[i] for i in NHWC_Order]
61-
input_A.shape = [input_A.shape[i] for i in NHWC_Order]
62-
input_B.shape = [input_B.shape[i] for i in NHWC_Order]
6359

6460
input_A_rescaled_to_int32 = build_rescale_to_int32(
6561
tosa_graph,

backends/arm/operators/op_batch_norm.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
register_node_visitor,
1212
)
1313
from executorch.backends.arm.tosa_mapping import TosaArg
14-
from executorch.backends.arm.tosa_utils import promote_shape
14+
from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
1515
from serializer.tosa_serializer import TosaOp
1616

1717

@@ -25,12 +25,9 @@ def __init__(self, *args):
2525
# For BatchNorm2D, mean and var are calculated over the channel dimension
2626
# But TOSA doesn't allow subtraction of inputs with different ranks
2727
# Need to augment the shapes to match the ranks with activations
28-
def augment_shape_rank(self, input, permute_memory_to_nhwc):
29-
return (
30-
(1, 1, 1) + input.shape
31-
if permute_memory_to_nhwc
32-
else ((1,) + input.shape + (1, 1))
33-
)
28+
def augment_shape_rank(self, shape, dim_order):
29+
nchw_shape = (1, *shape, 1, 1)
30+
return tosa_shape(nchw_shape, dim_order)
3431

3532
def define_node(
3633
self,
@@ -39,7 +36,6 @@ def define_node(
3936
inputs: List[TosaArg],
4037
output: TosaArg,
4138
is_quant_node: bool,
42-
permute_memory_to_nhwc: bool,
4339
) -> None:
4440
# Decompose batch norm into sequence
4541
(activations, weights, bias, running_mean, running_var, momentum, epsilon) = (
@@ -67,13 +63,15 @@ def define_node(
6763
mean_reshaped = promote_shape(
6864
tosa_graph,
6965
running_mean,
70-
self.augment_shape_rank(running_mean, permute_memory_to_nhwc),
66+
self.augment_shape_rank(running_mean.shape, output.dim_order),
7167
input_dtype,
7268
)
7369

7470
# Subtract mean
7571
# %op1 = tosa.SUB(%activations, %running_mean)
76-
op1 = tosa_graph.addIntermediate(output.shape, input_dtype)
72+
op1 = tosa_graph.addIntermediate(
73+
tosa_shape(output.shape, output.dim_order), input_dtype
74+
)
7775
tosa_graph.addOperator(
7876
TosaOp.Op().SUB,
7977
[activations.name, mean_reshaped.name],
@@ -82,7 +80,9 @@ def define_node(
8280
# Adding eplison to variance
8381
# %op2 = tosa.ADD(%running_var, %epsilon_const)
8482
epsilon_const = tosa_graph.addConst([1], input_dtype, [epsilon.number])
85-
op2 = tosa_graph.addIntermediate(running_var.shape, input_dtype)
83+
op2 = tosa_graph.addIntermediate(
84+
tosa_shape(running_var.shape, running_var.dim_order), input_dtype
85+
)
8686
tosa_graph.addOperator(
8787
TosaOp.Op().ADD,
8888
[running_var.name, epsilon_const.name],
@@ -97,7 +97,7 @@ def define_node(
9797
op3_reshaped = promote_shape(
9898
tosa_graph,
9999
op3,
100-
self.augment_shape_rank(running_var, permute_memory_to_nhwc),
100+
self.augment_shape_rank(running_var.shape, output.dim_order),
101101
input_dtype,
102102
)
103103

@@ -114,7 +114,9 @@ def define_node(
114114
else:
115115
# Multiply shifted activations with reciprocal variance
116116
# %op4 = tosa.MUL(%op1, %op3)
117-
op4 = tosa_graph.addIntermediate(output.shape, input_dtype)
117+
op4 = tosa_graph.addIntermediate(
118+
tosa_shape(output.shape, output.dim_order), input_dtype
119+
)
118120
attr_mul = ts.TosaSerializerAttribute()
119121
attr_mul.MulAttribute(0)
120122
tosa_graph.addOperator(
@@ -130,7 +132,7 @@ def define_node(
130132
weights_reshaped = promote_shape(
131133
tosa_graph,
132134
weights,
133-
self.augment_shape_rank(weights, permute_memory_to_nhwc),
135+
self.augment_shape_rank(weights.shape, output.dim_order),
134136
input_dtype,
135137
)
136138

@@ -152,7 +154,7 @@ def define_node(
152154
bias_reshaped = promote_shape(
153155
tosa_graph,
154156
bias,
155-
self.augment_shape_rank(bias, permute_memory_to_nhwc),
157+
self.augment_shape_rank(bias.shape, output.dim_order),
156158
input_dtype,
157159
)
158160

@@ -170,12 +172,14 @@ def define_node(
170172
weights_reshaped = promote_shape(
171173
tosa_graph,
172174
weights,
173-
self.augment_shape_rank(weights, permute_memory_to_nhwc),
175+
self.augment_shape_rank(weights.shape, output.dim_order),
174176
input_dtype,
175177
)
176178

177179
# %op5 = tosa.MUL(%op4, %weights)
178-
op5 = tosa_graph.addIntermediate(output.shape, input_dtype)
180+
op5 = tosa_graph.addIntermediate(
181+
tosa_shape(output.shape, output.dim_order), input_dtype
182+
)
179183
attr_mul = ts.TosaSerializerAttribute()
180184
attr_mul.MulAttribute(0)
181185
tosa_graph.addOperator(
@@ -189,7 +193,7 @@ def define_node(
189193
bias_reshaped = promote_shape(
190194
tosa_graph,
191195
bias,
192-
self.augment_shape_rank(bias, permute_memory_to_nhwc),
196+
self.augment_shape_rank(bias.shape, output.dim_order),
193197
input_dtype,
194198
)
195199

backends/arm/operators/op_conv2d.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
build_rescale_conv_output,
1616
get_quant_node_args,
1717
)
18-
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs
18+
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape
1919

2020
from serializer.tosa_serializer import TosaOp
2121

@@ -107,7 +107,9 @@ def define_node(
107107
# The output type is int32 when input type is int8.
108108
conv2d_output_name = output.name
109109
if is_quant_node:
110-
conv2d_res = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
110+
conv2d_res = tosa_graph.addIntermediate(
111+
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
112+
)
111113
conv2d_output_name = conv2d_res.name
112114

113115
# Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)

backends/arm/operators/op_div.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -11,6 +11,7 @@
1111
register_node_visitor,
1212
)
1313
from executorch.backends.arm.tosa_mapping import TosaArg
14+
from executorch.backends.arm.tosa_utils import tosa_shape
1415
from serializer.tosa_serializer import TosaOp
1516

1617

@@ -30,7 +31,9 @@ def define_node(
3031
is_quant_node: bool,
3132
) -> None:
3233
# FP32 Div is implemented as output=x/y -> output=x*1/y e.g. MUL(x,RECIPROCAL(y))
33-
recip = tosa_graph.addIntermediate(inputs[1].shape, inputs[1].dtype)
34+
recip = tosa_graph.addIntermediate(
35+
tosa_shape(inputs[1].shape, inputs[1].dim_order), inputs[1].dtype
36+
)
3437
tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[1].name], [recip.name])
3538

3639
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_output.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import serializer.tosa_serializer as ts
7+
import torch
8+
9+
10+
def process_output(
11+
node: torch.fx.Node,
12+
tosa_graph: ts.TosaSerializer,
13+
):
14+
for output in node.args[0]:
15+
tosa_graph.addOutputTensor(
16+
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
17+
)

0 commit comments

Comments
 (0)