Skip to content

Commit ec18ba3

Browse files
tarun292facebook-github-bot
authored andcommitted
Add fuse_dq_q_pass in exir/passes and also add it to HTP backend (#2295)
Summary: Passes such as https://fburl.com/code/vs6n4vcv remove no-ops from the graph. The problem is that after such a pass runs, the dq->q nodes that surrounded the removed op are still left in the graph and may then be delegated to the backend, causing perf regressions. This pass removes a dq->q pair when the qparams of the two ops have the same values; otherwise it leaves the pair untouched. Reviewed By: kimishpatel Differential Revision: D54543323
1 parent 7edd2fa commit ec18ba3

File tree

3 files changed

+216
-28
lines changed

3 files changed

+216
-28
lines changed

exir/passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ python_library(
155155
deps = [
156156
"//caffe2:torch",
157157
"//executorch/exir:pass_base",
158+
"//executorch/exir/dialects:lib",
158159
],
159160
)
160161

exir/passes/remove_noop_pass.py

Lines changed: 77 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,88 @@
66

77
# pyre-strict
88

9+
from typing import List, Tuple
10+
911
import torch
10-
from executorch.exir.pass_base import ExportPass, ProxyValue
11-
from torch.utils import _pytree as pytree
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
from torch.fx import GraphModule
15+
16+
# Dequantize overloads recognised by this module, listed in both their
# torch.ops and exir_ops.edge forms.
# The annotation is a variable-length tuple: `Tuple[X]` would mean "exactly one X".
_DEQUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
)
# Quantize overloads recognised by this module, mirroring _DEQUANT_OPS.
_QUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
)
28+
29+
30+
def eliminate_dq_q(
    graph_module: GraphModule,
    dequant_nodes: List[torch.fx.Node],
    check_qparams: bool = True,
) -> None:
    """Bypass quantize nodes that directly consume one of ``dequant_nodes``.

    For every dequantize node given, each user whose target is in
    ``_QUANT_OPS`` has all of its uses redirected to the dequantize node's
    input (``node.args[0]``). The nodes themselves are not erased here; the
    caller is expected to run dead-code elimination afterwards.

    Args:
        graph_module: The module owning the nodes. It is not read by this
            function and is kept for interface compatibility.
        dequant_nodes: Nodes whose target must be in ``_DEQUANT_OPS``.
        check_qparams: When True, a dq->q pair is only bypassed if every
            positional argument after the input tensor is equal on both nodes.

    Raises:
        AssertionError: If a node in ``dequant_nodes`` is not a dequantize op.
    """
    for node in dequant_nodes:
        assert node.target in _DEQUANT_OPS
        # Snapshot the users: replace_all_uses_with mutates use-def chains
        # while we iterate.
        for user in list(node.users):
            if user.target not in _QUANT_OPS:
                continue
            # Drop the input arg and check that the qparams are the same.
            # NOTE(review): only positional args are compared; this assumes the
            # qparams are never passed as kwargs -- confirm against callers.
            if check_qparams and node.args[1:] != user.args[1:]:
                continue
            user.replace_all_uses_with(node.args[0])
1245

1346

1447
class RemoveNoopPass(ExportPass):
    """
    Removes noops that pass through arguments.

    The ops considered are ``aten.to.dtype``, ``aten.dropout.default`` and
    ``aten.slice_copy.Tensor``. When the input of a removed op is a
    dequantize node, a following quantize node with matching qparams is
    bypassed as well, so that no redundant dq->q pair is left behind.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """Rewrite ``graph_module`` in place and return it in a PassResult.

        The result is always reported as modified.
        """
        # In this list we'll collect all the dequant nodes that are inputs to ops that
        # are removed in this pass and later check for redundant dq->q patterns and
        # remove them.
        dequant_nodes = []

        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in (
                torch.ops.aten.to.dtype,
                torch.ops.aten.dropout.default,
                torch.ops.aten.slice_copy.Tensor,
            ):
                continue

            orig_tensor = node.args[0].meta["val"]
            out_tensor = node.meta["val"]

            # The op is a noop when its output is the very same value object as
            # its input, or when it is a slice_copy that keeps the full size.
            is_noop = orig_tensor is out_tensor or (
                node.target == torch.ops.aten.slice_copy.Tensor
                and orig_tensor.size() == out_tensor.size()
            )
            if not is_noop:
                continue

            # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
            # Otherwise, removing only the op will suffice.
            if node.args[0].target in _DEQUANT_OPS:
                dequant_nodes.append(node.args[0])
            node.replace_all_uses_with(node.args[0])

        # Erase the bypassed noops first so that the quantize nodes become
        # direct users of the dequantize nodes collected above.
        graph_module.graph.eliminate_dead_code()
        eliminate_dq_q(graph_module, dequant_nodes)
        graph_module.graph.lint()
        # Second sweep erases the now-unused dq/q nodes.
        graph_module.graph.eliminate_dead_code()

        return PassResult(graph_module, True)

exir/tests/test_passes.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Import passes
1616
import executorch.exir.memory_planning # noqa
1717
import torch
18-
from executorch.exir import EdgeCompileConfig, memory, to_edge
18+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
1919
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
2020
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2121
from executorch.exir.emit import emit_program
@@ -50,6 +50,12 @@
5050
from functorch.experimental import control_flow
5151

5252
from torch import nn
53+
54+
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
55+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
56+
get_symmetric_quantization_config,
57+
XNNPACKQuantizer,
58+
)
5359
from torch.export import export
5460
from torch.fx import GraphModule, subgraph_rewriter
5561
from torch.fx.experimental.proxy_tensor import make_fx
@@ -1244,3 +1250,134 @@ def forward(self, x):
12441250
# %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
12451251
# return (copy__default, aten_add_tensor)
12461252
self.assertEqual(count_copies(gm), 1)
1253+
1254+
def test_remove_quantized_op_noop_pass(self) -> None:
    """Check dq/q handling around a slice_copy that is, or is not, a noop.

    Each model is quantized with XNNPACKQuantizer, exported and lowered with
    to_edge. dq/q node counts taken before to_edge are compared with the
    counts in the resulting edge program.
    NOTE(review): RemoveNoopPass is never invoked directly here; the test
    relies on to_edge running it -- confirm.
    """

    class TestAddSliceNoop(torch.nn.Module):
        def forward(self, x):
            x = x + x
            # x[:] keeps the full size, so the slice is expected to be a noop.
            x = x + x[:]
            return x

    class TestAddSliceNotNoop(torch.nn.Module):
        def forward(self, x):
            x = x + x
            # x[:1] changes the size, so the slice must be preserved.
            x = x + x[:1]
            return x

    dq_targets = (
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    )
    q_targets = (
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    )

    def count_nodes(gm: torch.fx.GraphModule, targets) -> int:
        # Number of graph nodes whose target is one of `targets`.
        return sum(node.target in targets for node in gm.graph.nodes)

    def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
        return count_nodes(gm, dq_targets)

    def count_q_nodes(gm: torch.fx.GraphModule) -> int:
        return count_nodes(gm, q_targets)

    def quantize_model(
        m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
    ) -> Tuple[EdgeProgramManager, int, int]:
        # Returns the edge program plus the dq and q node counts measured on
        # the exported program before to_edge is applied.
        # program capture
        m = torch._export.capture_pre_autograd_graph(
            m_eager,
            example_inputs,
        )

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config()
        quantizer.set_global(quantization_config)
        m = prepare_pt2e(m, quantizer)
        m = convert_pt2e(m, fold_quantize=True)
        ep = torch.export.export(m, example_inputs)
        dq_nodes_pre = count_dq_nodes(ep.graph_module)
        q_nodes_pre = count_q_nodes(ep.graph_module)
        edge = to_edge(
            ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
        )
        return edge, dq_nodes_pre, q_nodes_pre

    example_inputs = (torch.randn(9, 8),)
    model = TestAddSliceNoop()
    m_eager = model.eval()
    edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

    dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
    q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
    # One dq and one q node around the slice copy should have been removed.
    self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
    self.assertEqual(q_nodes_pre - q_nodes_post, 1)

    # Check that the slice_copy is removed by the RemoveNoopPass.
    for node in edge.exported_program().graph_module.graph.nodes:
        self.assertFalse("slice" in str(node.target))

    model = TestAddSliceNotNoop()
    m_eager = model.eval()
    edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

    dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
    q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
    # The slice copy is not a noop here, so no dq or q node may be removed.
    self.assertEqual(dq_nodes_pre, dq_nodes_post)
    self.assertEqual(q_nodes_pre, q_nodes_post)

    # Check that the slice_copy is not removed by the RemoveNoopPass.
    self.assertTrue(
        any(
            "slice" in str(node.target)
            for node in edge.exported_program().graph_module.graph.nodes
        )
    )
1352+
1353+
def test_dq_q_no_op_pass(self) -> None:
    """Check that a bare dq->q pair with no op in between survives to_edge."""

    class TestDqQ(torch.nn.Module):
        def forward(self, x):
            dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                x, 1.0, 0, -128, 127, torch.int8
            )
            q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                dq, 1.0, 0, -128, 127, torch.int8
            )
            return q

    model = TestDqQ()
    m_eager = model.eval()
    ep = torch.export.export(m_eager, (torch.randn(9, 8),))
    edge = to_edge(ep)

    targets = [
        node.target for node in edge.exported_program().graph_module.graph.nodes
    ]
    dq_targets = (
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    )
    q_targets = (
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    )
    # Check that the dq and q nodes are not touched by the RemoveNoopPass.
    # Targets are matched exactly: a substring test for "quantize" would also
    # match every dequantize node and make the second assertion vacuous.
    self.assertTrue(any(target in dq_targets for target in targets))
    self.assertTrue(any(target in q_targets for target in targets))

0 commit comments

Comments
 (0)