Skip to content

Commit b69d87a

Browse files
Arm backend: Refactor is_consumer_node_dw_conv2d
Refactor is_consumer_node_depthwise_conv2d to not use TosaArg. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I8a556ce290242b0a660f56a8f4048cba806817af
1 parent cd3b53d commit b69d87a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

backends/arm/tosa_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,14 @@ def get_new_shape(l_rank_in, h_rank_in):
153153
return reshaped, input2
154154

155155

156-
def is_consumer_node_depthwise_conv2d(node):
156+
def is_consumer_node_depthwise_conv2d(node: Node):
157157
consumer_node = list(node.users)[0]
158158
if consumer_node.target == exir_ops.edge.aten.convolution.default:
159-
inputs = getNodeArgs(consumer_node)
160-
group = inputs[-1]
161-
in_channels = inputs[0].shape[1]
162-
out_channels = inputs[1].shape[0]
163-
if (in_channels == group.number) and (out_channels % in_channels) == 0:
159+
consumer_node_inputs = consumer_node.all_input_nodes
160+
groups = consumer_node.args[-1]
161+
in_channels = consumer_node_inputs[0].meta["val"].shape[1]
162+
out_channels = consumer_node_inputs[1].meta["val"].shape[0]
163+
if (in_channels == groups) and (out_channels % in_channels) == 0:
164164
return True
165165

166166
return False

0 commit comments

Comments
 (0)