
Commit f004294

tarun292 authored and facebook-github-bot committed
Add fuse_dq_q_pass in exir/passes and also add it to HTP backend (#2295)
Summary: There are passes such as https://fburl.com/code/vs6n4vcv that remove no-ops from the graph. The problem is that after such a pass runs, the dq->q node pairs are left behind in the graph and can then get delegated to the backend, causing perf regressions. This pass removes a dq->q pair when the two nodes' qparams have the same values; otherwise it leaves them untouched. Differential Revision: D54543323
1 parent 47b837b commit f004294
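
To make the motivation concrete, here is a small illustrative sketch (not part of this commit) of why a dequantize -> quantize pair with identical qparams is a no-op. The scale and zero-point values are made up for the example, and the _decomposed import is only there to register the quantized_decomposed ops.

# Illustration only: a dq -> q round trip with matching qparams returns the
# original quantized tensor, which is why FuseDQandQPass can drop the pair.
import torch
from torch.ao.quantization.fx import _decomposed  # noqa: F401  (registers quantized_decomposed ops)

scale, zero_point, qmin, qmax, dtype = 0.05, 0, -128, 127, torch.int8
x_int8 = torch.randint(qmin, qmax + 1, (4,), dtype=dtype)

x_fp = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x_int8, scale, zero_point, qmin, qmax, dtype
)
x_requant = torch.ops.quantized_decomposed.quantize_per_tensor(
    x_fp, scale, zero_point, qmin, qmax, dtype
)
assert torch.equal(x_requant, x_int8)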

File tree: 4 files changed, +100 −0 lines changed

exir/passes/TARGETS

13 additions, 0 deletions

@@ -10,6 +10,7 @@ python_library(
     deps = [
         ":const_prop_pass",
         ":debug_handle_generator_pass",
+        ":fuse_dq_q_pass",
         ":insert_write_back_for_buffers_pass",
         ":memory_format_ops_pass",
         ":memory_planning_pass",
@@ -299,3 +300,15 @@ python_library(
         "//executorch/exir/dialects/edge:lib",
     ],
 )
+
+python_library(
+    name = "fuse_dq_q_pass",
+    srcs = [
+        "fuse_dq_q_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)

exir/passes/fuse_dq_q_pass.py

37 additions, 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


class FuseDQandQPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if (
                node.target
                == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            ):
                if all(
                    user.target
                    == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
                    for user in list(node.users)
                ):
                    for user in list(node.users):
                        # Drop the input arg and check that the qparams are the same.
                        qparams_dq = list(node.args)[1:]
                        qparams_q = list(user.args)[1:]
                        if qparams_dq != qparams_q:
                            continue
                        user.replace_all_uses_with(node.args[0])

        graph_module.graph.lint()
        graph_module.graph.eliminate_dead_code()
        return PassResult(graph_module, True)

exir/tests/TARGETS

1 addition, 0 deletions

@@ -213,6 +213,7 @@ python_unittest(
         "//executorch/exir/emit:lib",
         "//executorch/exir/passes:constant_prop_pass",
         "//executorch/exir/passes:debug_handle_generator_pass",
+        "//executorch/exir/passes:fuse_dq_q_pass",
         "//executorch/exir/passes:insert_write_back_for_buffers_pass",
         "//executorch/exir/passes:lib",
         "//executorch/exir/passes:remove_graph_asserts_pass",

exir/tests/test_passes.py

49 additions, 0 deletions

@@ -33,6 +33,7 @@
 )
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
 from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
+from executorch.exir.passes.fuse_dq_q_pass import FuseDQandQPass
 from executorch.exir.passes.insert_write_back_for_buffers_pass import (
     insert_write_back_for_buffers_pass,
 )
@@ -50,6 +51,12 @@
 from functorch.experimental import control_flow

 from torch import nn
+
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
 from torch.export import export
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -1244,3 +1251,45 @@ def forward(self, x):
         # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
         # return (copy__default, aten_add_tensor)
         self.assertEqual(count_copies(gm), 1)
+
+    def test_dq_q_fusion_pass(self) -> None:
+        class TestLinearAdd(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(8, 16)
+
+            def forward(self, x):
+                x1 = self.linear1(x)
+                x1 = x1 + x1
+                return x1
+
+        example_inputs = (torch.randn(9, 8),)
+        model = TestLinearAdd()
+        m_eager = model.eval()
+
+        # program capture
+        m = torch._export.capture_pre_autograd_graph(
+            m_eager,
+            example_inputs,
+        )
+
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config()
+        quantizer.set_global(quantization_config)
+        m = prepare_pt2e(m, quantizer)
+        m = convert_pt2e(m, fold_quantize=True)
+        ep = torch.export.export(m, example_inputs)
+
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
+        for node in edge.exported_program().graph_module.graph.nodes:
+            # Remove add node so that we can test the transform pass which should
+            # remove the dq and q nodes.
+            if "add" in node.name:
+                node.replace_all_uses_with(node.args[0])
+        edge.exported_program().graph_module.graph.eliminate_dead_code()
+
+        len_pre_transform = len(edge.exported_program().graph_module.graph.nodes)
+        edge.transform([FuseDQandQPass()])
+        len_post_transform = len(edge.exported_program().graph_module.graph.nodes)
+        # As one dq and one q node are removed, the number of nodes should be reduced by 2.
+        self.assertEqual(len_pre_transform - len_post_transform, 2)
