
 # pyre-strict

+from typing import Tuple
+
 import torch
-from executorch.exir.pass_base import ExportPass, ProxyValue
-from torch.utils import _pytree as pytree
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule
+
+_DEQUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+)
+_QUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+)


 class RemoveNoopPass(ExportPass):
     """
     Removes noops that pass through arguments.
     """

-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in (
-            torch.ops.aten.to.dtype,
-            torch.ops.aten.dropout.default,
-            torch.ops.aten.slice_copy.Tensor,
-        ):
-            return super().call_operator(op, args, kwargs, meta)
-
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        orig_tensor = (
-            args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
-        )
-
-        if orig_tensor is op(*args_data, **kwargs_data):
-            return args[0]
-
-        if op == torch.ops.aten.slice_copy.Tensor:
-            result = op(*args_data, **kwargs_data)
-            if orig_tensor.size() == result.size():
-                return args[0]
-
-        return super().call_operator(op, args, kwargs, meta)
+    def remove_quantized_op(
+        self, graph_module: GraphModule, node: torch.fx.Node
+    ) -> None:
+        node_input = list(node.args)[0]
+
+        if not isinstance(node_input, torch.fx.Node):
+            return
+
+        # Assume the graph entering this code matches the pattern
+        # Node A -> DQ -> noop (e.g. slice_copy) -> Q -> Node B. If the DQ and
+        # Q share the same qparams, rewrite it to Node A -> Node B.
+        if node_input.target in _DEQUANT_OPS:
+            for user in list(node.users):
+                if user.target in _QUANT_OPS:
+                    # Drop the input arg and check that the qparams match.
+                    qparams_dq = list(node_input.args)[1:]
+                    qparams_q = list(user.args)[1:]
+                    if qparams_dq != qparams_q:
+                        return
+                    # Route Q's users directly to DQ's input (Node A).
+                    user.replace_all_uses_with(node_input.args[0])
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in (
+                torch.ops.aten.to.dtype,
+                torch.ops.aten.dropout.default,
+                torch.ops.aten.slice_copy.Tensor,
+            ):
+                continue
+
+            orig_tensor = node.args[0].meta["val"]
+
+            if orig_tensor is node.meta["val"]:
+                # If the graph is quantized, remove the entire dq -> op -> q
+                # pattern; otherwise removing only the op suffices.
+                if node.args[0].target in _DEQUANT_OPS:
+                    self.remove_quantized_op(graph_module, node)
+                else:
+                    node.replace_all_uses_with(node.args[0])
+                continue
+
+            if node.target == torch.ops.aten.slice_copy.Tensor:
+                if orig_tensor.size() == node.meta["val"].size():
+                    # If the graph is quantized, remove the entire dq -> op -> q
+                    # pattern; otherwise removing only the op suffices.
+                    if node.args[0].target in _DEQUANT_OPS:
+                        self.remove_quantized_op(graph_module, node)
+                    else:
+                        node.replace_all_uses_with(node.args[0])
+
+        graph_module.graph.lint()
+        graph_module.graph.eliminate_dead_code()
+        return PassResult(graph_module, True)
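
For intuition, here is a minimal sketch in plain eager PyTorch (no executorch required) of the checks the pass relies on: `aten.to.dtype` and `aten.dropout` return their input tensor unchanged when they are no-ops, so an identity (`is`) check catches them, while `aten.slice_copy` always allocates a new tensor, so a full-range slice is detected by comparing sizes instead.

```python
import torch

x = torch.randn(4, 8)

# aten.to.dtype returns the input tensor itself when the dtype already
# matches (and no copy is forced), so `is` detects the no-op.
assert torch.ops.aten.to.dtype(x, torch.float32) is x

# aten.dropout is the identity when train=False.
assert torch.ops.aten.dropout.default(x, 0.5, False) is x

# slice_copy always materializes a new tensor, so the identity check fails;
# a full-range slice is instead recognized by its unchanged size.
full = torch.ops.aten.slice_copy.Tensor(x, 0, 0, x.size(0))
assert full is not x and full.size() == x.size()
```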
0 commit comments
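And a sketch of exercising the pass end to end on an exported graph. The `SliceNoop` module is hypothetical, and the `executorch.exir.passes.remove_noop_pass` import path and the direct-invocation flow are assumptions based on the surrounding source layout, not something this diff confirms; it relies on `torch.export` populating `meta["val"]` on graph nodes.

```python
import torch
from torch.export import export

# Import path assumed from executorch's exir/passes layout; adjust if it differs.
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass


class SliceNoop(torch.nn.Module):
    """Hypothetical module whose slice covers the full range, i.e. a no-op."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.ops.aten.slice_copy.Tensor(x, 0, 0, 4)  # same size as x
        return y + 1


gm = export(SliceNoop(), (torch.randn(4, 8),)).graph_module
result = RemoveNoopPass()(gm)

# The full-range slice should be rewired to its input and then dropped by
# eliminate_dead_code().
assert all(
    n.target is not torch.ops.aten.slice_copy.Tensor
    for n in result.graph_module.graph.nodes
)
```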