Commit e860f35

tarun292 authored and facebook-github-bot committed
Add fuse_dq_q_pass in exir/passes and also add it to HTP backend (#2295)
Summary: There are passes, such as https://fburl.com/code/vs6n4vcv, that remove no-ops from the graph. The problem is that after such a pass runs, the dq->q node pairs around the removed op are left in the graph, and they can then get delegated to the backend and cause perf regressions. This pass removes a dq->q pair if the qparams of the two nodes are identical; if not, it leaves them untouched. Differential Revision: D54543323
1 parent 42eeebc commit e860f35
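
To make the pattern concrete, here is a minimal standalone sketch (illustrative only, not part of this commit) of why a dq -> op -> q chain with identical qparams is a no-op on the quantized values. The scale/zero_point values are arbitrary, and the explicit import of torch.ao.quantization.fx._decomposed is there to register the quantized_decomposed ops:

import torch
import torch.ao.quantization.fx._decomposed  # registers torch.ops.quantized_decomposed

qx = torch.randint(-128, 127, (4, 4), dtype=torch.int8)

# Dequantize with example qparams scale=0.1, zero_point=0.
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    qx, 0.1, 0, -128, 127, torch.int8
)
# A full-range slice_copy leaves the tensor contents unchanged.
sliced = torch.ops.aten.slice_copy.Tensor(dq, 0, 0, 4)
# Re-quantize with the *same* qparams: the original int8 values come back,
# so the whole dq -> slice_copy -> q chain collapses to the identity.
requant = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    sliced, 0.1, 0, -128, 127, torch.int8
)
assert torch.equal(requant, qx)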

File tree

2 files changed: +138 -27 lines changed


exir/passes/remove_noop_pass.py

Lines changed: 68 additions & 27 deletions
@@ -7,37 +7,78 @@
 # pyre-strict

 import torch
-from executorch.exir.pass_base import ExportPass, ProxyValue
-from torch.utils import _pytree as pytree
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule
+
+_DEQUANT_OPS = (
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+)
+_QUANT_OPS = (
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+)


 class RemoveNoopPass(ExportPass):
     """
     Removes noops that pass through arguments.
     """

-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in (
-            torch.ops.aten.to.dtype,
-            torch.ops.aten.dropout.default,
-            torch.ops.aten.slice_copy.Tensor,
-        ):
-            return super().call_operator(op, args, kwargs, meta)
-
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        orig_tensor = (
-            args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
-        )
-
-        if orig_tensor is op(*args_data, **kwargs_data):
-            return args[0]
-
-        if op == torch.ops.aten.slice_copy.Tensor:
-            result = op(*args_data, **kwargs_data)
-            if orig_tensor.size() == result.size():
-                return args[0]
-
-        return super().call_operator(op, args, kwargs, meta)
+    def remove_quantized_op(
+        self, graph_module: GraphModule, node: torch.fx.Node
+    ) -> None:
+        node_input = list(node.args)[0]
+
+        if not isinstance(node_input, torch.fx.Node):
+            return
+
+        # Let's assume that when entering this section of code the graph pattern is as follows:
+        # Node A -> DQ -> slice_copy -> Q -> Node B. If the qparams of the DQ and Q are the same,
+        # then after this the graph will look like this:
+        # Node A -> Node B.
+        if node_input.target in _DEQUANT_OPS:
+            for user in list(node.users):
+                if user.target in _QUANT_OPS:
+                    # Drop the input arg and check that the qparams are the same.
+                    qparams_dq = list(node_input.args)[1:]
+                    qparams_q = list(user.args)[1:]
+                    if qparams_dq != qparams_q:
+                        return
+                    user.replace_all_uses_with(node_input.args[0])
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in (
+                torch.ops.aten.to.dtype,
+                torch.ops.aten.dropout.default,
+                torch.ops.aten.slice_copy.Tensor,
+            ):
+                continue
+
+            orig_tensor = node.args[0].meta["val"]
+
+            if orig_tensor is node.meta["val"]:
+                # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
+                # Otherwise, removing only the op will suffice.
+                if node.args[0].target in _DEQUANT_OPS:
+                    self.remove_quantized_op(graph_module, node)
+                else:
+                    node.replace_all_uses_with(node.args[0])
+                continue
+
+            if node.target == torch.ops.aten.slice_copy.Tensor:
+                if orig_tensor.size() == node.meta["val"].size():
+                    # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
+                    # Otherwise, removing only the op will suffice.
+                    if node.args[0].target in _DEQUANT_OPS:
+                        self.remove_quantized_op(graph_module, node)
+                    else:
+                        node.replace_all_uses_with(node.args[0])
+
+        graph_module.graph.lint()
+        graph_module.graph.eliminate_dead_code()
+        return PassResult(graph_module, True)
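
A quick usage sketch of the rewritten pass (the toy model and the call pattern below are illustrative assumptions, not part of this diff). An ExportPass instance is callable on a GraphModule and returns a PassResult:

import torch

from executorch.exir.passes.remove_noop_pass import RemoveNoopPass

class M(torch.nn.Module):
    def forward(self, x):
        return x[:] + 1.0  # full-range slice: a candidate no-op

ep = torch.export.export(M(), (torch.randn(4, 4),))
# PassBase.__call__ dispatches to RemoveNoopPass.call and wraps the result.
# Note: the pass matches torch.ops.aten.slice_copy.Tensor, which typically
# appears after edge decompositions; a raw export may still carry aten.slice,
# in which case this invocation is a harmless no-op.
result = RemoveNoopPass()(ep.graph_module)
print(result.graph_module.graph)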

exir/tests/test_passes.py

Lines changed: 70 additions & 0 deletions
@@ -50,6 +50,12 @@
 from functorch.experimental import control_flow

 from torch import nn
+
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
 from torch.export import export
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -1244,3 +1250,67 @@ def forward(self, x):
         # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
         # return (copy__default, aten_add_tensor)
         self.assertEqual(count_copies(gm), 1)
+
+    def test_remove_quantized_op_noop_pass(self) -> None:
+        class TestAddSlice(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = x + x
+                x = x + x[:]
+                return x
+
+        def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
+            return sum(
+                (
+                    node.target
+                    in (
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+                    )
+                )
+                for node in gm.graph.nodes
+            )
+
+        def count_q_nodes(gm: torch.fx.GraphModule) -> int:
+            return sum(
+                (
+                    node.target
+                    in (
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+                    )
+                )
+                for node in gm.graph.nodes
+            )
+
+        example_inputs = (torch.randn(9, 8),)
+        model = TestAddSlice()
+        m_eager = model.eval()
+
+        # program capture
+        m = torch._export.capture_pre_autograd_graph(
+            m_eager,
+            example_inputs,
+        )
+
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config()
+        quantizer.set_global(quantization_config)
+        m = prepare_pt2e(m, quantizer)
+        m = convert_pt2e(m, fold_quantize=True)
+        ep = torch.export.export(m, example_inputs)
+        dq_nodes_pre = count_dq_nodes(ep.graph_module)
+        q_nodes_pre = count_q_nodes(ep.graph_module)
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
+
+        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
+        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
+        # One dq and one q node around the slice copy should have been removed.
+        self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
+        self.assertEqual(q_nodes_pre - q_nodes_post, 1)
+
+        # Check that the slice_copy is removed by the RemoveNoopPass.
+        for node in edge.exported_program().graph_module.graph.nodes:
+            self.assertFalse("slice" in str(node.target))
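
One plausible way to run just the new test from the repo root, assuming a pytest setup (the runner used internally may differ):

python -m pytest exir/tests/test_passes.py -k test_remove_quantized_op_noop_pass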
