Skip to content

Commit 1a5f23d

Browse files
tarun292facebook-github-bot
authored andcommitted
Add fuse_dq_q_pass in exir/passes and also add it to HTP backend (#2295)
Summary: There are passes such as these https://fburl.com/code/vs6n4vcv that remove noops from the graph. The problem is that after this pass runs it still leaves in the dq->q nodes in the graph which then potentially get delegated to the backend causing perf regressions. This pass will remove the dq->q ops if their qparams are of the same value. If not it won't touch them. Reviewed By: kimishpatel Differential Revision: D54543323
1 parent 091e524 commit 1a5f23d

File tree

3 files changed

+254
-28
lines changed

3 files changed

+254
-28
lines changed

exir/passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ python_library(
154154
deps = [
155155
"//caffe2:torch",
156156
"//executorch/exir:pass_base",
157+
"//executorch/exir/dialects:lib",
157158
],
158159
)
159160

exir/passes/remove_noop_pass.py

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,87 @@
66

77
# pyre-strict
88

9+
from typing import List, Tuple
10+
911
import torch
10-
from executorch.exir.pass_base import ExportPass, ProxyValue
11-
from torch.utils import _pytree as pytree
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
from torch.fx import GraphModule
15+
16+
_DEQUANT_OPS: Tuple[torch._ops.OpOverload] = (
17+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
18+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
19+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
20+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
21+
)
22+
_QUANT_OPS: Tuple[torch._ops.OpOverload] = (
23+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
24+
torch.ops.quantized_decomposed.quantize_per_channel.default,
25+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
26+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
27+
)
28+
29+
30+
def eliminate_dq_q(
31+
graph_module: GraphModule,
32+
dequant_nodes: List[torch.fx.Node],
33+
) -> None:
34+
for node in dequant_nodes:
35+
assert node.target in _DEQUANT_OPS
36+
for user in list(node.users):
37+
if user.target in _QUANT_OPS:
38+
# Drop the input arg and check that the qparams are the same.
39+
qparams_dq = list(node.args)[1:]
40+
qparams_q = list(user.args)[1:]
41+
if qparams_dq != qparams_q:
42+
continue
43+
user.replace_all_uses_with(node.args[0])
1244

1345

1446
class RemoveNoopPass(ExportPass):
1547
"""
1648
Removes noops that pass through arguments.
1749
"""
1850

19-
# pyre-ignore
20-
def call_operator(self, op, args, kwargs, meta):
21-
if op not in (
22-
torch.ops.aten.to.dtype,
23-
torch.ops.aten.dropout.default,
24-
torch.ops.aten.slice_copy.Tensor,
25-
):
26-
return super().call_operator(op, args, kwargs, meta)
27-
28-
args_data, kwargs_data = pytree.tree_map_only(
29-
ProxyValue, lambda x: x.data, (args, kwargs)
30-
)
31-
orig_tensor = (
32-
args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
33-
)
34-
35-
if orig_tensor is op(*args_data, **kwargs_data):
36-
return args[0]
37-
38-
if op == torch.ops.aten.slice_copy.Tensor:
39-
result = op(*args_data, **kwargs_data)
40-
if orig_tensor.size() == result.size():
41-
return args[0]
42-
43-
return super().call_operator(op, args, kwargs, meta)
51+
def call(self, graph_module: GraphModule) -> PassResult:
52+
53+
# In this list we'll collect all the dequant nodes that are inputs to ops that
54+
# are removed in this pass and later check for redundant dq->q patterns and
55+
# remove them.
56+
dequant_nodes = []
57+
58+
for node in graph_module.graph.nodes:
59+
if node.op != "call_function":
60+
continue
61+
62+
if node.target not in (
63+
torch.ops.aten.to.dtype,
64+
torch.ops.aten.dropout.default,
65+
torch.ops.aten.slice_copy.Tensor,
66+
):
67+
continue
68+
69+
orig_tensor = node.args[0].meta["val"]
70+
71+
if orig_tensor is node.meta["val"]:
72+
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
73+
# Otherwise, removing only the op will suffice.
74+
if node.args[0].target in _DEQUANT_OPS:
75+
dequant_nodes += [node.args[0]]
76+
node.replace_all_uses_with(node.args[0])
77+
continue
78+
79+
if node.target == torch.ops.aten.slice_copy.Tensor:
80+
if orig_tensor.size() == node.meta["val"].size():
81+
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
82+
# Otherwise, removing only the op will suffice.
83+
if node.args[0].target in _DEQUANT_OPS:
84+
dequant_nodes += [node.args[0]]
85+
node.replace_all_uses_with(node.args[0])
86+
87+
graph_module.graph.eliminate_dead_code()
88+
eliminate_dq_q(graph_module, dequant_nodes)
89+
graph_module.graph.lint()
90+
graph_module.graph.eliminate_dead_code()
91+
92+
return PassResult(graph_module, True)

exir/tests/test_passes.py

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Import passes
1616
import executorch.exir.memory_planning # noqa
1717
import torch
18-
from executorch.exir import EdgeCompileConfig, memory, to_edge
18+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
1919
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
2020
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2121
from executorch.exir.emit import emit_program
@@ -50,6 +50,12 @@
5050
from functorch.experimental import control_flow
5151

5252
from torch import nn
53+
54+
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
55+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
56+
get_symmetric_quantization_config,
57+
XNNPACKQuantizer,
58+
)
5359
from torch.export import export
5460
from torch.fx import GraphModule, subgraph_rewriter
5561
from torch.fx.experimental.proxy_tensor import make_fx
@@ -1244,3 +1250,173 @@ def forward(self, x):
12441250
# %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
12451251
# return (copy__default, aten_add_tensor)
12461252
self.assertEqual(count_copies(gm), 1)
1253+
1254+
def test_remove_quantized_op_noop_pass(self) -> None:
1255+
class TestAddSliceNoop(torch.nn.Module):
1256+
def __init__(self):
1257+
super().__init__()
1258+
1259+
def forward(self, x):
1260+
x = x + x
1261+
x = x + x[:]
1262+
return x
1263+
1264+
class TestAddSliceNotNoop(torch.nn.Module):
1265+
def __init__(self):
1266+
super().__init__()
1267+
1268+
def forward(self, x):
1269+
x = x + x
1270+
x = x + x[:1]
1271+
return x
1272+
1273+
def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
1274+
return sum(
1275+
(
1276+
node.target
1277+
in (
1278+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1279+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
1280+
)
1281+
)
1282+
for node in gm.graph.nodes
1283+
)
1284+
1285+
def count_q_nodes(gm: torch.fx.GraphModule) -> int:
1286+
return sum(
1287+
(
1288+
node.target
1289+
in (
1290+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
1291+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
1292+
)
1293+
)
1294+
for node in gm.graph.nodes
1295+
)
1296+
1297+
def quantize_model(
1298+
m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
1299+
) -> Tuple[EdgeProgramManager, int, int]:
1300+
# program capture
1301+
m = torch._export.capture_pre_autograd_graph(
1302+
m_eager,
1303+
example_inputs,
1304+
)
1305+
1306+
quantizer = XNNPACKQuantizer()
1307+
quantization_config = get_symmetric_quantization_config()
1308+
quantizer.set_global(quantization_config)
1309+
m = prepare_pt2e(m, quantizer)
1310+
m = convert_pt2e(m, fold_quantize=True)
1311+
ep = torch.export.export(m, example_inputs)
1312+
dq_nodes_pre = count_dq_nodes(ep.graph_module)
1313+
q_nodes_pre = count_q_nodes(ep.graph_module)
1314+
edge = to_edge(
1315+
ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
1316+
)
1317+
return edge, dq_nodes_pre, q_nodes_pre
1318+
1319+
example_inputs = (torch.randn(9, 8),)
1320+
model = TestAddSliceNoop()
1321+
m_eager = model.eval()
1322+
edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)
1323+
1324+
dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
1325+
q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
1326+
# One dq and one q node around the slice copy should have been removed.
1327+
self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
1328+
self.assertEqual(q_nodes_pre - q_nodes_post, 1)
1329+
1330+
# Check that the slice_copy is removed by the RemoveNoopPass.
1331+
for node in edge.exported_program().graph_module.graph.nodes:
1332+
self.assertFalse("slice" in str(node.target))
1333+
1334+
model = TestAddSliceNotNoop()
1335+
m_eager = model.eval()
1336+
edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)
1337+
1338+
dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
1339+
q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
1340+
# One dq and one q node around the slice copy should have been removed.
1341+
self.assertEqual(dq_nodes_pre, dq_nodes_post)
1342+
self.assertEqual(q_nodes_pre, q_nodes_post)
1343+
1344+
# Check that the slice_copy is not removed by the RemoveNoopPass.
1345+
self.assertTrue(
1346+
any(
1347+
"slice" in str(node.target)
1348+
for node in edge.exported_program().graph_module.graph.nodes
1349+
)
1350+
)
1351+
1352+
def test_dq_q_no_op_pass(self) -> None:
1353+
class TestDqQ(torch.nn.Module):
1354+
def __init__(self):
1355+
super().__init__()
1356+
1357+
def forward(self, x):
1358+
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1359+
x, 1.0, 0, -128, 127, torch.int8
1360+
)
1361+
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
1362+
dq, 1.0, 0, -128, 127, torch.int8
1363+
)
1364+
return q
1365+
1366+
model = TestDqQ()
1367+
m_eager = model.eval()
1368+
ep = torch.export.export(m_eager, (torch.randn(9, 8),))
1369+
edge = to_edge(ep)
1370+
# Check that the dq and q nodes are not touched by the RemoveNoopPass.
1371+
self.assertTrue(
1372+
any(
1373+
"dequantize" in str(node.target)
1374+
for node in edge.exported_program().graph_module.graph.nodes
1375+
)
1376+
)
1377+
self.assertTrue(
1378+
any(
1379+
"quantize" in str(node.target)
1380+
for node in edge.exported_program().graph_module.graph.nodes
1381+
)
1382+
)
1383+
1384+
def test_dq_q_different_qparams(self) -> None:
1385+
class TestDqQDifferentQParam(torch.nn.Module):
1386+
def __init__(self):
1387+
super().__init__()
1388+
1389+
def forward(self, x):
1390+
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1391+
x, 1.0, 0, -128, 127, torch.int8
1392+
)
1393+
slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0)
1394+
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
1395+
slice_copy_output, 1.0, 0, -127, 127, torch.int8
1396+
)
1397+
return q
1398+
1399+
model = TestDqQDifferentQParam()
1400+
m_eager = model.eval()
1401+
ep = torch.export.export(m_eager, (torch.randn(9, 8),))
1402+
edge = to_edge(ep)
1403+
print(edge.exported_program().graph_module.graph)
1404+
# Check that the dq and q nodes are not touched by the RemoveNoopPass.
1405+
self.assertTrue(
1406+
any(
1407+
"dequantize" in str(node.target)
1408+
for node in edge.exported_program().graph_module.graph.nodes
1409+
)
1410+
)
1411+
self.assertTrue(
1412+
any(
1413+
"quantize" in str(node.target)
1414+
for node in edge.exported_program().graph_module.graph.nodes
1415+
)
1416+
)
1417+
self.assertFalse(
1418+
any(
1419+
"slice" in str(node.target)
1420+
for node in edge.exported_program().graph_module.graph.nodes
1421+
)
1422+
)

0 commit comments

Comments
 (0)