 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe

-
-import torch
-from executorch.backends.arm._passes.arm_pass_utils import (
-    create_node,
-    get_param_tensor,
-    is_param_node,
-)
-from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import ExportPass


 class Conv1dUnsqueezePass(ExportPass):
     """
     This pass is used to change conv1d ops into conv2d since TOSA only
     supports 2d and 3d convolution. This is done by modifying the graph to do the
     following:
-    1) unsqueeze the convolution's input from 3d to 4d
+    1a) unsqueeze the convolution's input from 3d to 4d
+    1b) unsqueeze the convolution's weight from 3d to 4d
     2) perform a conv2d (with a modified version of the original conv1d args)
     3) squeeze the output back down to 3d.
     """

-    def __init__(self, exported_program: ExportedProgram) -> None:
-        super().__init__()
-        self.exported_program = exported_program
-
-    def unsqueeze_kernel_weights(self, kernel_node):
-        """
-        Unsqueezes the weights of a conv1d to make it 4 dimensional.
-
-        Args:
-            kernel_node: the weights of conv1d node to be unsqueezed
-        """
-        kernel_param_3d = get_param_tensor(self.exported_program, kernel_node)
-        if kernel_param_3d is None:
-            raise AssertionError("Expected param tensor for the kernel node")
-
-        kernel_param_4d = torch.nn.Parameter(
-            data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
-            requires_grad=False,
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.aten.convolution.default:
+            return super().call_operator(op, args, kwargs, meta)
+        stride = list(args[3])
+        if len(stride) != 1:
+            return super().call_operator(op, args, kwargs, meta)
+
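+        # 1a) Unsqueeze the input from (N, C, L) to (N, C, L, 1).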
+        x = args[0]
+        x_unsqueezed_shape = list(x.data.shape) + [1]
+        x = super().call_operator(
+            exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta
         )

-        if torch._export.utils.is_param(self.exported_program, kernel_node):
-            parameter_name = self.exported_program.graph_signature.inputs_to_parameters[
-                kernel_node.name
-            ]
-            self.exported_program.state_dict[parameter_name] = kernel_param_4d
-            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
-        elif torch._export.utils.is_buffer(self.exported_program, kernel_node):
-            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
-                kernel_node.name
-            ]
-            self.exported_program.state_dict[buffer_name] = kernel_param_4d
-            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
-        elif torch._export.utils.is_lifted_tensor_constant(
-            self.exported_program, kernel_node
-        ):
-            buffer_name = (
-                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
-                    kernel_node.name
-                ]
-            )
-            self.exported_program.constants[buffer_name] = kernel_param_4d
-            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
-        else:
-            setattr(
-                kernel_node.graph.owning_module,
-                kernel_node.target,
-                kernel_param_4d,
-            )
-
-    def call(self, graph_module: torch.fx.GraphModule):
-        graph = graph_module.graph
-        node_list = list(graph.nodes)
-        for node in node_list:
-            if node.op == "call_function":
-                if node.target == exir_ops.edge.aten.convolution.default:
-                    stride = list(node.args[3])
-                    if len(stride) != 1:
-                        # skip conv if it is not 1d
-                        continue
-
-                    kernel_node = node.args[1]
-
-                    if not is_param_node(self.exported_program, kernel_node):
-                        raise AssertionError(
-                            "Expected op for convolution weight node to be a get_attr node or a parameter"
-                        )
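+        # The weight view gets a copy of the node meta with the quantization
+        # annotations cleared, since the qparams recorded on `meta` describe
+        # the convolution itself, not the weight view.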
+        w_meta = meta.copy()
+        w_meta.data["input_qparams"] = {}
+        w_meta.data["output_qparams"] = {}

-                    # Modify graph such that the conv changes from 1d to 2d
-                    self.unsqueeze_kernel_weights(kernel_node)
-
-                    # (b) Extend stride, padding, and dilation for extra dim
-                    node.args = (
-                        node.args[0],
-                        node.args[1],
-                        node.args[2],
-                        node.args[3] + [1],  # stride
-                        node.args[4] + [0],  # padding
-                        node.args[5] + [1],  # dilation
-                        node.args[6],
-                        node.args[7] + [0],
-                        node.args[8],
-                    )
-
-                    # c. Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d)
-                    # unsqueeze -> conv2d -> squeeze
-                    with graph.inserting_before(node):
-                        input_node = node.args[0]
-                        unsqueeze_before = create_node(
-                            graph, exir_ops.edge.aten.unsqueeze_copy.default
-                        )
-                        unsqueeze_before.args = (
-                            input_node,  # Input is node's original input
-                            -1,  # Last Dimension
-                        )
-                        node.replace_input_with(input_node, unsqueeze_before)
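+        # 1b) Unsqueeze the weight from (C_out, C_in, K) to (C_out, C_in, K, 1).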
+        w = args[1]
+        w_unsqueezed_shape = list(w.data.shape) + [1]
+        w = super().call_operator(
+            exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta
+        )

-                    with graph.inserting_after(node):
-                        squeeze_after = create_node(
-                            graph,
-                            exir_ops.edge.aten.squeeze_copy.dims,
-                        )
-                        squeeze_after.args = (
-                            node,  # Input is the conv node
-                            [-1],  # Last dimension
-                        )
-                        original_users = [
-                            user for user in node.users if user != squeeze_after
-                        ]
-                        for user in original_users:
-                            user.replace_input_with(node, squeeze_after)
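+        # 2) Run the convolution as a conv2d: stride, padding, dilation, and
+        #    output padding each gain a trailing entry for the new unit dim.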
+        new_args = (
+            x,
+            w,
+            args[2],
+            args[3] + [1],  # stride
+            args[4] + [0],  # padding
+            args[5] + [1],  # dilation
+            args[6],
+            args[7] + [0],
+            args[8],
+        )
+        x = super().call_operator(
+            exir_ops.edge.aten.convolution.default, new_args, kwargs, meta
+        )

-        graph_module.recompile()
-        # Since we are overriding "call", we need to call the parent's "call"
-        # to retrace the graph and regenerate metadata
-        graph_module = super().call(graph_module).graph_module
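+        # 3) Squeeze the output back to 3d by dropping the trailing unit dim.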
+        x_squeezed_shape = list(x.data.shape)[:-1]
+        x = super().call_operator(
+            exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta
+        )

-        return PassResult(graph_module, True)
+        return x
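
The rewrite relies on the identity that a conv1d is a conv2d over a trailing unit dimension. A quick numerical sanity sketch of that identity in plain PyTorch (shapes and names here are illustrative, not taken from the pass):

```python
import torch

x = torch.randn(1, 2, 16)  # input  (N, C_in, L)
w = torch.randn(4, 2, 3)   # weight (C_out, C_in, K)
b = torch.randn(4)

y_1d = torch.nn.functional.conv1d(x, w, b)

# The same convolution expressed the way the pass rewrites it:
# unsqueeze input and weight to 4d, run conv2d, squeeze the result.
y_2d = torch.nn.functional.conv2d(
    x.unsqueeze(-1),  # (N, C_in, L, 1)
    w.unsqueeze(-1),  # (C_out, C_in, K, 1)
    b,
    stride=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
).squeeze(-1)

assert torch.allclose(y_1d, y_2d, atol=1e-6)
```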
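Because the pass now works purely through `call_operator`, it no longer needs an `ExportedProgram` at construction time. A minimal sketch of running it standalone; the import path and the `to_edge(...).transform(...)` plumbing are assumptions (the Arm backend normally invokes this pass from its own pass manager):

```python
import torch
from executorch.exir import to_edge

# Assumed module path for the file touched by this diff.
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import (
    Conv1dUnsqueezePass,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(2, 4, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


ep = torch.export.export(M(), (torch.randn(1, 2, 16),))
edge = to_edge(ep)

# The pass is constructed with no arguments now that it does not touch
# the ExportedProgram's state_dict or graph signature.
edge = edge.transform([Conv1dUnsqueezePass()])
print(edge.exported_program().graph_module.graph)
```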