Commit 3066463

Arm backend: Refactor Conv1dUnsqueezePass (#11144)
Simplifies the pass structure and removes the need for using the exported_program.

Signed-off-by: Adrian Lundell <[email protected]>
1 parent: d147a2c · commit: 3066463
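For context, the rewrite relies on conv1d being expressible as conv2d over tensors with a trailing unit dimension, which is what the pass constructs with view_copy ops. A minimal sketch of that equivalence in plain PyTorch (shapes and values here are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8)  # conv1d input: (batch, in_channels, length)
w = torch.randn(4, 3, 5)  # conv1d weight: (out_channels, in_channels, kernel_size)

y_1d = F.conv1d(x, w, stride=1, padding=2, dilation=1)

# Same computation as conv2d: unsqueeze input and weight from 3d to 4d,
# extend stride/padding/dilation with 1/0/1 for the new axis, squeeze back.
y_2d = F.conv2d(
    x.unsqueeze(-1),
    w.unsqueeze(-1),
    stride=(1, 1),
    padding=(2, 0),
    dilation=(1, 1),
).squeeze(-1)

assert torch.allclose(y_1d, y_2d, atol=1e-6)

Since the kernel size along the new axis is 1, with zero padding and unit stride, the extra output dimension is always 1, so squeezing it back off is safe.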

2 files changed (+44, -125 lines)

backends/arm/_passes/arm_pass_manager.py
Lines changed: 2 additions & 2 deletions

@@ -118,7 +118,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
-        self.add_pass(Conv1dUnsqueezePass(exported_program))
+        self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())

@@ -173,7 +173,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
-        self.add_pass(Conv1dUnsqueezePass(exported_program))
+        self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())

backends/arm/_passes/conv1d_unsqueeze_pass.py
Lines changed: 42 additions & 123 deletions
@@ -1,148 +1,67 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe

-
-import torch
-from executorch.backends.arm._passes.arm_pass_utils import (
-    create_node,
-    get_param_tensor,
-    is_param_node,
-)
-from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import ExportPass


 class Conv1dUnsqueezePass(ExportPass):
     """
     This pass is used to change conv1d ops into conv2d since TOSA only
     supports 2d and 3d convolution. This is done by modifying the graph to do the
     following:
-    1) unsqueeze the convolution's input from 3d to 4d
+    1a) unsqueeze the convolution's input from 3d to 4d
+    1b) unsqueeze the convolution's weight from 3d to 4d
     2) perform a conv2d (with a modified version of the original conv1d args)
     3) squeeze the output back down to 3d.
     """

-    def __init__(self, exported_program: ExportedProgram) -> None:
-        super().__init__()
-        self.exported_program = exported_program
-
-    def unsqueeze_kernel_weights(self, kernel_node):
-        """
-        Unsqueezes the weights of a conv1d to make it 4 dimensional.
-
-        Args:
-            kernel_node: the weights of conv1d node to be unsqueezed
-        """
-        kernel_param_3d = get_param_tensor(self.exported_program, kernel_node)
-        if kernel_param_3d is None:
-            raise AssertionError("Expected param tensor for the kernel node")
-
-        kernel_param_4d = torch.nn.Parameter(
-            data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
-            requires_grad=False,
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.aten.convolution.default:
+            return super().call_operator(op, args, kwargs, meta)
+        stride = list(args[3])
+        if len(stride) != 1:
+            return super().call_operator(op, args, kwargs, meta)
+
+        x = args[0]
+        x_unsqueezed_shape = list(x.data.shape) + [1]
+        x = super().call_operator(
+            exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta
         )

-        if torch._export.utils.is_param(self.exported_program, kernel_node):
-            parameter_name = self.exported_program.graph_signature.inputs_to_parameters[
-                kernel_node.name
-            ]
-            self.exported_program.state_dict[parameter_name] = kernel_param_4d
-            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
-        elif torch._export.utils.is_buffer(self.exported_program, kernel_node):
-            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
-                kernel_node.name
-            ]
-            self.exported_program.state_dict[buffer_name] = kernel_param_4d
-            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
-        elif torch._export.utils.is_lifted_tensor_constant(
-            self.exported_program, kernel_node
-        ):
-            buffer_name = (
-                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
-                    kernel_node.name
-                ]
-            )
-            self.exported_program.constants[buffer_name] = kernel_param_4d
-            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
-        else:
-            setattr(
-                kernel_node.graph.owning_module,
-                kernel_node.target,
-                kernel_param_4d,
-            )
-
-    def call(self, graph_module: torch.fx.GraphModule):
-        graph = graph_module.graph
-        node_list = list(graph.nodes)
-        for node in node_list:
-            if node.op == "call_function":
-                if node.target == exir_ops.edge.aten.convolution.default:
-                    stride = list(node.args[3])
-                    if len(stride) != 1:
-                        # skip conv if it is not 1d
-                        continue
-
-                    kernel_node = node.args[1]
-
-                    if not is_param_node(self.exported_program, kernel_node):
-                        raise AssertionError(
-                            "Expected op for convolution weight node to be a get_attr node or a parameter"
-                        )
+        w_meta = meta.copy()
+        w_meta.data["input_qparams"] = {}
+        w_meta.data["output_qparams"] = {}

-                    # Modify graph such that the conv changes from 1d to 2d
-                    self.unsqueeze_kernel_weights(kernel_node)
-
-                    # (b) Extend stride, padding, and dilation for extra dim
-                    node.args = (
-                        node.args[0],
-                        node.args[1],
-                        node.args[2],
-                        node.args[3] + [1],  # stride
-                        node.args[4] + [0],  # padding
-                        node.args[5] + [1],  # dilation
-                        node.args[6],
-                        node.args[7] + [0],
-                        node.args[8],
-                    )
-
-                    # c. Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d)
-                    # unsqueeze -> conv2d -> squeeze
-                    with graph.inserting_before(node):
-                        input_node = node.args[0]
-                        unsqueeze_before = create_node(
-                            graph, exir_ops.edge.aten.unsqueeze_copy.default
-                        )
-                        unsqueeze_before.args = (
-                            input_node,  # Input is node's original input
-                            -1,  # Last Dimension
-                        )
-                        node.replace_input_with(input_node, unsqueeze_before)
+        w = args[1]
+        w_unsqueezed_shape = list(w.data.shape) + [1]
+        w = super().call_operator(
+            exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta
+        )

-                    with graph.inserting_after(node):
-                        squeeze_after = create_node(
-                            graph,
-                            exir_ops.edge.aten.squeeze_copy.dims,
-                        )
-                        squeeze_after.args = (
-                            node,  # Input is the conv node
-                            [-1],  # Last dimension
-                        )
-                        original_users = [
-                            user for user in node.users if user != squeeze_after
-                        ]
-                        for user in original_users:
-                            user.replace_input_with(node, squeeze_after)
+        new_args = (
+            x,
+            w,
+            args[2],
+            args[3] + [1],  # stride
+            args[4] + [0],  # padding
+            args[5] + [1],  # dilation
+            args[6],
+            args[7] + [0],
+            args[8],
+        )
+        x = super().call_operator(
+            exir_ops.edge.aten.convolution.default, new_args, kwargs, meta
+        )

-        graph_module.recompile()
-        # Since we are overriding "call", we need to call the parent's "call"
-        # to retrace the graph and regenerate metadata
-        graph_module = super().call(graph_module).graph_module
+        x_squeezed_shape = list(x.data.shape)[:-1]
+        x = super().call_operator(
+            exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta
+        )

-        return PassResult(graph_module, True)
+        return x
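Because the pass now builds its replacement ops inside call_operator, ExportPass's retracing handles graph rewiring and metadata regeneration, and the weight is just another traced value whose shape can be changed with a view_copy rather than a state_dict entry, so no ExportedProgram is needed. A hypothetical standalone usage sketch (the module path is assumed from the repo layout, and graph_module stands for an already-lowered edge-dialect torch.fx.GraphModule; in practice the pass runs inside the ArmPassManager pipelines shown above):

from executorch.backends.arm._passes.conv1d_unsqueeze_pass import (
    Conv1dUnsqueezePass,
)

# ExportPass instances are callable on a GraphModule and return a PassResult
# whose .graph_module holds the retraced graph with conv1d rewritten to conv2d.
result = Conv1dUnsqueezePass()(graph_module)
graph_module = result.graph_module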
