Commit 68eaac7

nathanaelsee authored and facebook-github-bot committed
FuseDequantLinearPass to convert dq -> linear into weight_int8pack_mm (#4708)
Summary: Pull Request resolved: #4708

Replaces `dq(weight) -> linear(activation, dq)` with `weight_int8pack_mm`.
Replaces `dq(weight) -> linear(activation, dq, bias)` with `weight_int8pack_mm -> add`.

Reviewed By: copyrightly

Differential Revision: D60945766
1 parent f1b741e commit 68eaac7
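
The rewrite preserves numerics for symmetric per-channel int8 weights (the fused op takes only scales, so zero-points are implicitly zero, and bias must be added separately). A minimal before/after sketch, assuming a recent PyTorch build that ships `aten._weight_int8pack_mm` on CPU; all tensor names here are illustrative, not from the commit:

```python
# Minimal numerical sketch (not part of the commit).
import torch

m, k, n = 4, 32, 8
activations = torch.randn(m, k)
quant_weight = torch.randint(-127, 128, (n, k), dtype=torch.int8)
scale = torch.rand(n) * 0.01  # one scale per output channel
bias = torch.randn(n)

# Before the pass: dq(weight) -> linear(activation, dq, bias)
dq_weight = quant_weight.to(torch.float32) * scale.unsqueeze(-1)
reference = torch.nn.functional.linear(activations, dq_weight, bias)

# After the pass: weight_int8pack_mm -> add
fused = torch.ops.aten._weight_int8pack_mm(activations, quant_weight, scale)
fused = fused + bias

torch.testing.assert_close(reference, fused)
```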

File tree

2 files changed: +92, -0 lines

backends/transforms/TARGETS

Lines changed: 15 additions & 0 deletions
@@ -73,6 +73,21 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "fuse_dequant_linear",
+    srcs = ["fuse_dequant_linear.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        ":utils",
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir:sym_util",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "fuse_view_copy",
     srcs = ["fuse_view_copy.py"],
backends/transforms/fuse_dequant_linear.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseDequantLinearPass(ExportPass):
    """
    Fuses weight dequantize_per_channel nodes with linear nodes into
    weight_int8pack_mm nodes, for 8-bit weight-only quantization.

    Replaces dq(weight) -> linear(activation, dq) with weight_int8pack_mm
    Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add
    """

    def fuse_dequant_with_linear(
        self,
        graph_module: torch.fx.GraphModule,
        dequant_node: torch.fx.Node,
        linear_node: torch.fx.Node,
    ) -> None:
        activations = linear_node.args[0]
        bias = None
        if len(linear_node.args) > 2:
            bias = linear_node.args[2]
        quant_weight = dequant_node.args[0]
        scale = dequant_node.args[1]

        with graph_module.graph.inserting_before(linear_node):
            weight_int8pack_mm_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten._weight_int8pack_mm.default,
                (activations, quant_weight, scale),
            )
            if bias:
                add_node = graph_module.graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.add.Tensor,
                    (weight_int8pack_mm_node, bias),
                )
                linear_node.replace_all_uses_with(add_node)
            else:
                linear_node.replace_all_uses_with(weight_int8pack_mm_node)
            graph_module.graph.erase_node(linear_node)
            graph_module.graph.erase_node(dequant_node)

    def is_node_target(
        self, node: torch.fx.Node, target: torch._ops.OperatorBase
    ) -> bool:
        return node.op == "call_function" and node.target == target

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if self.is_node_target(node, exir_ops.edge.aten.linear.default):
                weight_node = node.args[1]
                if self.is_node_target(
                    weight_node,
                    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
                ):
                    # only fuse if weight tensor is int8 packed
                    quant_weight = weight_node.args[0]
                    if quant_weight.meta["val"].dtype != torch.int8:
                        continue
                    self.fuse_dequant_with_linear(graph_module, weight_node, node)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
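
As a usage illustration (not part of the commit): `ExportPass` instances are callable on a `torch.fx.GraphModule` and return a `PassResult`, so a backend pipeline could invoke the new pass roughly like this, where the wrapper function is assumed for illustration:

```python
# Hypothetical wiring sketch; only FuseDequantLinearPass is from this commit.
import torch

from executorch.backends.transforms.fuse_dequant_linear import (
    FuseDequantLinearPass,
)


def run_fuse_dequant_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # ExportPass is callable; PassResult.graph_module holds the rewritten graph.
    result = FuseDequantLinearPass()(gm)
    assert result is not None
    return result.graph_module
```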
