|
10 | 10 |
|
11 | 11 | import torch
|
12 | 12 | import torch.fx
|
| 13 | +import torch.nn.functional as F |
13 | 14 | from executorch.backends.arm.quantizer import QuantizationConfig
|
14 | 15 | from executorch.backends.arm.tosa_utils import get_node_debug_info
|
15 | 16 | from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
|
@@ -142,29 +143,33 @@ def _match_pattern(
|
142 | 143 |
|
143 | 144 | Each 'pattern' element is composed of a list of disjunctive nodes types.
|
144 | 145 | """
|
145 |
| - assert len(pattern) == 2, "Only two-nodes patterns supported currently" |
146 |
| - |
147 |
| - if node.target in pattern[0]: |
148 |
| - assert len(node.users) != 0 |
149 |
| - parent = node |
150 |
| - child = next(iter(node.users)) |
151 |
| - elif node.target in pattern[1]: |
152 |
| - assert len(node.args) != 0 |
153 |
| - parent = node.args[0] # type: ignore[assignment] |
154 |
| - child = node |
155 |
| - else: |
156 |
| - return False |
157 |
| - |
158 |
| - if len(parent.users) != 1: |
159 |
| - return False |
160 |
| - |
161 |
| - if parent.target not in pattern[0] or child.target not in pattern[1]: |
162 |
| - return False |
163 |
| - |
| 146 | + assert len(pattern) > 0, "No pattern provided" |
164 | 147 | if filter_fn is not None:
|
165 |
| - return filter_fn(parent) and filter_fn(child) |
166 |
| - |
167 |
| - return True |
| 148 | + if not filter_fn(node): |
| 149 | + return False |
| 150 | + if len(pattern) == 1: |
| 151 | + # Base case where it has passed the filter_fn. Simply look if node.target is in pattern. |
| 152 | + return node.target in pattern[0] |
| 153 | + if node.target not in [op for sub_pattern in pattern for op in sub_pattern]: |
| 154 | + # node.target not in pattern. No need to look at the rest of the pattern. |
| 155 | + return False |
| 156 | + # Find the index of this node's target in pattern |
| 157 | + idx = [node.target in sub_pattern for sub_pattern in pattern].index(True) |
| 158 | + left_pattern = pattern[:idx] |
| 159 | + # Exclude idx as this contains node.target which we have already matched |
| 160 | + right_pattern = pattern[idx + 1 :] |
| 161 | + left_condition = True |
| 162 | + right_condition = True |
| 163 | + # Recursively look at the rest of the pattern by calling this function for |
| 164 | + # node's input and user node with updated patterns. |
| 165 | + if len(left_pattern) > 0: |
| 166 | + parent = node.all_input_nodes[0] |
| 167 | + if len(parent.users) != 1: |
| 168 | + return False |
| 169 | + left_condition = _match_pattern(parent, left_pattern, filter_fn) |
| 170 | + if len(right_pattern) > 0: |
| 171 | + right_condition = _match_pattern(list(node.users)[0], right_pattern, filter_fn) |
| 172 | + return left_condition and right_condition |
168 | 173 |
|
169 | 174 |
|
170 | 175 | _one_to_one = [
|
@@ -274,6 +279,58 @@ def any_or_hardtanh_min_zero(n: Node):
|
274 | 279 | return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
|
275 | 280 |
|
276 | 281 | if _match_pattern(
|
| 282 | + node, |
| 283 | + [ |
| 284 | + [ |
| 285 | + torch.ops.aten.conv1d.default, |
| 286 | + torch.ops.aten.conv2d.default, |
| 287 | + torch.ops.aten.conv2d.padding, |
| 288 | + ], |
| 289 | + [torch.ops.aten.batch_norm.default, F.batch_norm], |
| 290 | + [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], |
| 291 | + ], |
| 292 | + filter_fn=any_or_hardtanh_min_zero, |
| 293 | + ): |
| 294 | + if node.target in ( |
| 295 | + torch.ops.aten.conv1d.default, |
| 296 | + torch.ops.aten.conv2d.default, |
| 297 | + torch.ops.aten.conv2d.padding, |
| 298 | + ): |
| 299 | + quant_properties.quant_inputs = [ |
| 300 | + _QuantProperty(0, input_act_qspec), |
| 301 | + _QuantProperty(1, weight_qspec, mark_annotated=True), |
| 302 | + _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), |
| 303 | + ] |
| 304 | + elif node.target in ( |
| 305 | + torch.ops.aten.relu.default, |
| 306 | + torch.ops.aten.hardtanh.default, |
| 307 | + ): |
| 308 | + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) |
| 309 | + |
| 310 | + elif _match_pattern( |
| 311 | + node, |
| 312 | + [ |
| 313 | + [ |
| 314 | + torch.ops.aten.conv1d.default, |
| 315 | + torch.ops.aten.conv2d.default, |
| 316 | + torch.ops.aten.conv2d.padding, |
| 317 | + ], |
| 318 | + [torch.ops.aten.batch_norm.default, F.batch_norm], |
| 319 | + ], |
| 320 | + ): |
| 321 | + if node.target in ( |
| 322 | + torch.ops.aten.conv1d.default, |
| 323 | + torch.ops.aten.conv2d.default, |
| 324 | + torch.ops.aten.conv2d.padding, |
| 325 | + ): |
| 326 | + quant_properties.quant_inputs = [ |
| 327 | + _QuantProperty(0, input_act_qspec), |
| 328 | + _QuantProperty(1, weight_qspec, mark_annotated=True), |
| 329 | + _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), |
| 330 | + ] |
| 331 | + elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]: |
| 332 | + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) |
| 333 | + elif _match_pattern( |
277 | 334 | node,
|
278 | 335 | [
|
279 | 336 | [
|
|
0 commit comments