Skip to content

Commit 0cd6a43

Browse files
Fixes for permute/annotation-pass
Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Ica6addb95d6b925beef4696780334268821af608
1 parent 372e36b commit 0cd6a43

File tree

7 files changed

+82
-123
lines changed

7 files changed

+82
-123
lines changed

backends/arm/arm_backend.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from executorch.backends.arm.operators.node_visitor import get_node_visitors
1919
from executorch.backends.arm.operators.op_output import process_output
2020
from executorch.backends.arm.operators.op_placeholder import process_placeholder
21-
from executorch.backends.arm.passes.permute_memory_pass import PermuteMemoryPass
21+
from executorch.backends.arm.passes.annotate_channels_last_dim_order_pass import (
22+
AnnotateChannelsLastDimOrder,
23+
)
2224
from executorch.backends.arm.tosa_utils import (
2325
dbg_fail,
2426
dbg_tosa_dump,
@@ -44,6 +46,7 @@ def __init__(self):
4446
self.compiler_flags = []
4547
self.output_format = None
4648
self.path_for_intermediates = None
49+
# TODO MLETORCH-265 Remove permute_nhwc flag
4750
self.permute_nhwc = False
4851
self.quantize_io = False
4952

@@ -245,7 +248,7 @@ def preprocess( # noqa: C901
245248
tosa_graph = ts.TosaSerializer(path)
246249
passes = PassManager()
247250
if permute_memory_to_nhwc:
248-
passes.add_pass(PermuteMemoryPass(edge_program))
251+
passes.add_pass(AnnotateChannelsLastDimOrder())
249252
passes(edge_program.graph_module)
250253

251254
node_visitors = get_node_visitors(edge_program)

backends/arm/operators/op_placeholder.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def process_inputs(
2121
node: torch.fx.Node,
2222
tosa_graph: ts.TosaSerializer,
2323
):
24+
"""Serialize an input node"""
2425
inputs = [TosaArg(node)]
2526
input_shape = inputs[0].shape
2627
input_dim_order = inputs[0].dim_order
@@ -39,6 +40,10 @@ def process_quantized_bias(
3940
tosa_graph: ts.TosaSerializer,
4041
parameter_values,
4142
):
43+
"""
44+
Serialize bias node that needs to be quantized.
45+
This can be either an addmm or conv bias node.
46+
"""
4247
consumer_node = list(node.users)[0]
4348
if is_bias_node_for_quantized_addmm(node):
4449
(
@@ -73,17 +78,12 @@ def process_quantized_bias(
7378
)
7479

7580

76-
def permute(data, dim_order):
77-
if len(data.shape) == 4:
78-
data = np.transpose(data, dim_order)
79-
return data
80-
81-
8281
def process_inputs_to_parameters(
8382
node: torch.fx.Node,
8483
tosa_graph: ts.TosaSerializer,
8584
edge_program: ExportedProgram,
8685
):
86+
"""Serialize bias and non-quantized weights"""
8787
inputs = [TosaArg(node)]
8888
parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
8989
parameter_data = edge_program.state_dict[parameter_name]
@@ -92,17 +92,11 @@ def process_inputs_to_parameters(
9292
parameter_values = parameter_data.detach().numpy()
9393

9494
if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node):
95+
# BI bias
9596
process_quantized_bias(node, tosa_graph, parameter_values)
9697
else:
97-
# Cases for:
98-
# - MI_AddMM_bias
99-
# - MI_AddMM_weight
100-
# - MI_Conv2d_non_bias_weight
101-
# - MI_Conv2d_weight
102-
# - MI_Conv2d_bias
103-
# - MI_DepthwiseConv2d_weight
104-
# - MI_DepthwiseConv2d_bias
105-
parameter_values = permute(parameter_values, inputs[0].dim_order)
98+
# MI weights or bias
99+
parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
106100

107101
tosa_graph.addConst(
108102
parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
@@ -114,6 +108,7 @@ def process_inputs_to_buffers(
114108
tosa_graph: ts.TosaSerializer,
115109
edge_program: ExportedProgram,
116110
):
111+
"""Serialize quantized weights"""
117112
inputs = [TosaArg(node)]
118113
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
119114
buffer_data = edge_program.state_dict[buffer_name]
@@ -124,7 +119,7 @@ def process_inputs_to_buffers(
124119
# TODO: fragile code for temporary fix
125120
# the mean and var tensors are also stored here but they have shape (1, )
126121
# we only transpose weights here
127-
buffer_values = permute(buffer_values, inputs[0].dim_order)
122+
buffer_values = np.transpose(buffer_values, inputs[0].dim_order)
128123

129124
tosa_graph.addConst(
130125
buffer_values.shape, inputs[0].dtype, buffer_values, name=node.name
@@ -136,6 +131,7 @@ def process_placeholder(
136131
tosa_graph: ts.TosaSerializer,
137132
edge_program: ExportedProgram,
138133
):
134+
"""Wrapper for processing and serializing all types of placeholders"""
139135
assert node.name == node.target, "Expect placeholder name and target to match"
140136
assert 0 == len(node.args), "Can't handle default input values"
141137

backends/arm/passes/annotate_channels_last_dim_order_pass.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.backends.arm.tosa_quant_utils import dq_op
9+
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
13+
class AnnotateChannelsLastDimOrder(ExportPass):
14+
"""
15+
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
16+
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes.
17+
The annotated tosa_dim_order is used to permute the node's shape such that it
18+
gives a TOSA-compliant shape.
19+
"""
20+
21+
def is_weight_node_for_dw_conv(self, node: torch.fx.Node):
22+
"""
23+
returns True for dq and w in the following sequences;
24+
w -> dw_conv -> ...
25+
w -> dq -> dw_conv -> ...
26+
"""
27+
if node.op == "call_function":
28+
if node.target != dq_op:
29+
return False
30+
prev_node = node.args[0]
31+
if prev_node.op != "placeholder":
32+
return False
33+
return is_consumer_node_depthwise_conv2d(node)
34+
elif node.op == "placeholder":
35+
# node is an input, weight or bias node
36+
consumer_node = list(node.users)[0]
37+
if self.is_weight_node_for_dw_conv(consumer_node):
38+
return True
39+
if is_consumer_node_depthwise_conv2d(node):
40+
# Check that node is the weight-argument and not input or bias
41+
return consumer_node.args[1] == node
42+
43+
return False
44+
45+
def call(self, graph_module: torch.fx.GraphModule):
46+
NHWC_Order = (0, 2, 3, 1)
47+
HWCM_Order = (2, 3, 0, 1)
48+
for node in graph_module.graph.nodes:
49+
if isinstance(node.meta["val"], tuple):
50+
node_data = node.meta["val"][0].data
51+
else:
52+
node_data = node.meta["val"].data
53+
54+
if len(node_data.shape) == 4:
55+
dim_order = NHWC_Order
56+
if self.is_weight_node_for_dw_conv(node):
57+
dim_order = HWCM_Order
58+
else:
59+
dim_order = tuple(range(node_data.dim()))
60+
node.meta["tosa_dim_order"] = dim_order
61+
graph_module.recompile()
62+
return PassResult(graph_module, True)

backends/arm/passes/arm_pass.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

backends/arm/passes/permute_memory_pass.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

backends/arm/test/ops/test_conv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def forward(self, x):
114114
return x
115115

116116

117-
conv2d_2x2_3x1x40x40_nobias = Conv2d(
117+
conv2d_2x2_3x2x40x40_nobias = Conv2d(
118118
in_channels=2,
119119
out_channels=3,
120120
kernel_size=(2, 2),
@@ -221,7 +221,7 @@ def forward(self, x):
221221
# Shenanigan to get a nicer output when test fails. With unittest it looks like:
222222
# FAIL: test_conv2d_tosa_BI_2_3x3_1x3x12x12_st2_pd1
223223
testsuite = [
224-
("2x2_3x1x40x40_nobias", conv2d_2x2_3x1x40x40_nobias),
224+
("2x2_3x2x40x40_nobias", conv2d_2x2_3x2x40x40_nobias),
225225
("3x3_1x3x256x256_st1", conv2d_3x3_1x3x256x256_st1),
226226
("3x3_1x3x12x12_st2_pd1", conv2d_3x3_1x3x12x12_st2_pd1),
227227
("1x1_1x2x128x128_st1", conv2d_1x1_1x2x128x128_st1),
@@ -236,7 +236,7 @@ def forward(self, x):
236236
# Check: https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/heads/main/SUPPORTED_OPS.md
237237
# IFM Tensor batch size must be 1 - [FULLY_CONNECTED, RESHAPE, SHAPE, SLICE, SOFTMAX, SPLIT, SPLIT_V, SQUEEZE, STRIDED_SLICE, UNPACK]
238238
testsuite_u55 = testsuite.copy()
239-
testsuite_u55.remove(("2x2_3x1x40x40_nobias", conv2d_2x2_3x1x40x40_nobias))
239+
testsuite_u55.remove(("2x2_3x2x40x40_nobias", conv2d_2x2_3x2x40x40_nobias))
240240
testsuite_u55.remove(("5x5_3x2x128x128_st1", conv2d_5x5_3x2x128x128_st1))
241241

242242

backends/arm/tosa_utils.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -87,33 +87,6 @@ def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
8787
return reshape_res
8888

8989

90-
# Helper transpose function to match TOSA's shape requirements
91-
# E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes:
92-
# https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
93-
def transpose_helper(tosa_fb, input, new_order, out_dtype):
94-
# Check new_order's length is equal to input rank
95-
assert len(input.shape) == len(new_order), "Wrong shape order length"
96-
97-
# Check no duplications
98-
assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"
99-
100-
# Check all dims are valid
101-
for idx in new_order:
102-
if idx < 0:
103-
assert True, "Negative dim number"
104-
elif idx >= len(input.shape):
105-
assert True, "Dim is greater than input rank"
106-
107-
input_shape_transpoed = [input.shape[i] for i in new_order]
108-
attr = ts.TosaSerializerAttribute()
109-
attr.TransposeAttribute(new_order)
110-
input_transposed = tosa_fb.addIntermediate(input_shape_transpoed, out_dtype)
111-
tosa_fb.addOperator(
112-
TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
113-
)
114-
return input_transposed
115-
116-
11790
def getNodeArgs(node):
11891
return [TosaArg(arg) for arg in node.args]
11992

0 commit comments

Comments
 (0)