
Commit 68d29f7

tarun292 authored and facebook-github-bot committed

Add fuse_dq_q_pass in exir/passes and also add it to HTP backend (#2295)

Summary: Passes such as https://fburl.com/code/vs6n4vcv remove no-ops from the graph. The problem is that after such a pass runs, it leaves the surrounding dq -> q nodes in the graph, and these can then get delegated to the backend, causing perf regressions. This pass removes a dq -> q pair when the qparams of the two nodes have the same values; otherwise it leaves them untouched.

Differential Revision: D54543323
1 parent 341d2d9 commit 68d29f7
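Why is a dq -> q pair with identical qparams a no-op? Dequantize maps a quantized value q to (q - zero_point) * scale, and quantizing that result with the same scale, zero_point, range, and dtype rounds straight back to q. A minimal numeric sketch, not part of this commit; it assumes the decomposed quant ops are registered by importing torch.ao.quantization.fx._decomposed:

```python
import torch

# Assumption: importing this module registers the torch.ops.quantized_decomposed
# ops used below; it is not part of the commit itself.
import torch.ao.quantization.fx._decomposed  # noqa: F401

scale, zero_point, qmin, qmax = 0.05, 0, -128, 127
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    torch.randn(4), scale, zero_point, qmin, qmax, torch.int8
)
# Dequantize computes (q - zero_point) * scale ...
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    q, scale, zero_point, qmin, qmax, torch.int8
)
# ... and re-quantizing with the *same* qparams reproduces q exactly,
# so the dq -> q pair is an identity and can be fused away.
requant = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    dq, scale, zero_point, qmin, qmax, torch.int8
)
assert torch.equal(q, requant)
```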

File tree

exir/passes/remove_noop_pass.py
exir/tests/test_passes.py

2 files changed: +140 -27 lines

exir/passes/remove_noop_pass.py

Lines changed: 70 additions & 27 deletions
```diff
@@ -6,38 +6,81 @@
 
 # pyre-strict
 
+from typing import Tuple
+
 import torch
-from executorch.exir.pass_base import ExportPass, ProxyValue
-from torch.utils import _pytree as pytree
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule
+
+_DEQUANT_OPS: Tuple[torch._ops.OpOverload] = (
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+)
+_QUANT_OPS: Tuple[torch._ops.OpOverload] = (
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+)
 
 
 class RemoveNoopPass(ExportPass):
     """
     Removes noops that pass through arguments.
     """
 
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in (
-            torch.ops.aten.to.dtype,
-            torch.ops.aten.dropout.default,
-            torch.ops.aten.slice_copy.Tensor,
-        ):
-            return super().call_operator(op, args, kwargs, meta)
-
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        orig_tensor = (
-            args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
-        )
-
-        if orig_tensor is op(*args_data, **kwargs_data):
-            return args[0]
-
-        if op == torch.ops.aten.slice_copy.Tensor:
-            result = op(*args_data, **kwargs_data)
-            if orig_tensor.size() == result.size():
-                return args[0]
-
-        return super().call_operator(op, args, kwargs, meta)
+    def remove_quantized_op(
+        self, graph_module: GraphModule, node: torch.fx.Node
+    ) -> None:
+        node_input = list(node.args)[0]
+
+        if not isinstance(node_input, torch.fx.Node):
+            return
+
+        # Let's assume that when entering this section of code the graph pattern is as follows:
+        # Node A -> DQ -> slice_copy -> Q -> Node B. If the qparams of the DQ and Q are the same,
+        # then after this the graph will look like this:
+        # Node A -> Node B.
+        if node_input.target in _DEQUANT_OPS:
+            for user in list(node.users):
+                if user.target in _QUANT_OPS:
+                    # Drop the input arg and check that the qparams are the same.
+                    qparams_dq = list(node_input.args)[1:]
+                    qparams_q = list(user.args)[1:]
+                    if qparams_dq != qparams_q:
+                        return
+                    user.replace_all_uses_with(node_input.args[0])
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in (
+                torch.ops.aten.to.dtype,
+                torch.ops.aten.dropout.default,
+                torch.ops.aten.slice_copy.Tensor,
+            ):
+                continue
+
+            orig_tensor = node.args[0].meta["val"]
+
+            if orig_tensor is node.meta["val"]:
+                # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
+                # Otherwise, removing only the op will suffice.
+                if node.args[0].target in _DEQUANT_OPS:
+                    self.remove_quantized_op(graph_module, node)
+                else:
+                    node.replace_all_uses_with(node.args[0])
+                continue
+
+            if node.target == torch.ops.aten.slice_copy.Tensor:
+                if orig_tensor.size() == node.meta["val"].size():
+                    # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
+                    # Otherwise, removing only the op will suffice.
+                    if node.args[0].target in _DEQUANT_OPS:
+                        self.remove_quantized_op(graph_module, node)
+                    else:
+                        node.replace_all_uses_with(node.args[0])
+
+        graph_module.graph.lint()
+        graph_module.graph.eliminate_dead_code()
+        return PassResult(graph_module, True)
```
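Two implementation notes. First, the pass moved from `call_operator`, which visits one operator at a time, to a graph-level `call`, because deciding whether a dq -> op -> q pattern can be dropped requires inspecting the no-op node's producer (the dq) and its users (the q), not just the node itself. Second, the qparams comparison relies on the argument layout of the decomposed quant ops: for `quantize_per_tensor`/`dequantize_per_tensor` the args are `(input, scale, zero_point, quant_min, quant_max, dtype)` (the per-channel variants are analogous), so slicing off the input with `args[1:]` leaves exactly the qparams. A standalone sketch of that check, with an illustrative helper name that is not part of the commit:

```python
from torch.fx import Node

def same_qparams(dq: Node, q: Node) -> bool:
    # Illustrative helper, not from the commit. For the decomposed quant ops
    # the node args are (input, scale, zero_point, quant_min, quant_max, dtype),
    # so everything after the input arg must match for the fusion to be safe.
    return list(dq.args)[1:] == list(q.args)[1:]
```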

exir/tests/test_passes.py

Lines changed: 70 additions & 0 deletions
```diff
@@ -50,6 +50,12 @@
 from functorch.experimental import control_flow
 
 from torch import nn
+
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
 from torch.export import export
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -1244,3 +1250,67 @@ def forward(self, x):
         # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
         # return (copy__default, aten_add_tensor)
         self.assertEqual(count_copies(gm), 1)
+
+    def test_remove_quantized_op_noop_pass(self) -> None:
+        class TestAddSlice(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = x + x
+                x = x + x[:]
+                return x
+
+        def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
+            return sum(
+                (
+                    node.target
+                    in (
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+                    )
+                )
+                for node in gm.graph.nodes
+            )
+
+        def count_q_nodes(gm: torch.fx.GraphModule) -> int:
+            return sum(
+                (
+                    node.target
+                    in (
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+                    )
+                )
+                for node in gm.graph.nodes
+            )
+
+        example_inputs = (torch.randn(9, 8),)
+        model = TestAddSlice()
+        m_eager = model.eval()
+
+        # program capture
+        m = torch._export.capture_pre_autograd_graph(
+            m_eager,
+            example_inputs,
+        )
+
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config()
+        quantizer.set_global(quantization_config)
+        m = prepare_pt2e(m, quantizer)
+        m = convert_pt2e(m, fold_quantize=True)
+        ep = torch.export.export(m, example_inputs)
+        dq_nodes_pre = count_dq_nodes(ep.graph_module)
+        q_nodes_pre = count_q_nodes(ep.graph_module)
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
+
+        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
+        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
+        # One dq and one q node around the slice copy should have been removed.
+        self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
+        self.assertEqual(q_nodes_pre - q_nodes_post, 1)
+
+        # Check that the slice_copy is removed by the RemoveNoopPass.
+        for node in edge.exported_program().graph_module.graph.nodes:
+            self.assertFalse("slice" in str(node.target))
```
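Beyond the end-to-end test above, a smaller way to exercise the pass is to call it directly on a graph module, since `ExportPass` subclasses are callable and return a `PassResult`. A hedged sketch, assuming that `run_decompositions()` lowers the full slice to `aten.slice_copy.Tensor`, the op the pass matches:

```python
import torch
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass

class FullSlice(torch.nn.Module):
    def forward(self, x):
        # x[:] is a full-size slice, i.e. a no-op once exported.
        return x[:] + x

ep = torch.export.export(FullSlice(), (torch.randn(4, 4),))
# Assumption: run_decompositions() functionalizes view ops into their
# *_copy variants, turning the slice into aten.slice_copy.Tensor.
ep = ep.run_decompositions()
result = RemoveNoopPass()(ep.graph_module)
assert result is not None
# The no-op slice should be gone after the pass runs.
assert not any(
    "slice" in str(n.target) for n in result.graph_module.graph.nodes
)
```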
