
Commit e6f7c03

nathanaelsee authored and facebook-github-bot committed
FuseDequantLinearPass to convert dq -> linear into weight_int8pack_mm (#4708)
Summary:
Pull Request resolved: #4708

Replaces `dq(weight) -> linear(activation, dq)` with `weight_int8pack_mm`.
Replaces `dq(weight) -> linear(activation, dq, bias)` with `weight_int8pack_mm -> add`.

Differential Revision: D60945766
1 parent 6efc222 commit e6f7c03
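For context, here is a sketch of the numerics this fusion relies on; it is not taken from the commit. After 8-bit weight-only quantization, the graph computes `linear(activation, dq(weight))`, which, when the zero point is zero, is mathematically the same matmul that the fused op performs. The shapes, the zero-point-0 assumption, and the direct call to the private `torch._weight_int8pack_mm` op are illustrative assumptions (and require a PyTorch build that ships a kernel for this op on your device):

```python
import torch

# Illustrative only: emulate the pattern the pass matches,
#   dequantize_per_channel(weight) -> linear(activation, dq_weight)
# and compare it against the fused int8 weight-only matmul.
activation = torch.randn(4, 8)
quant_weight = torch.randint(-128, 128, (16, 8), dtype=torch.int8)
scale = torch.rand(16)  # one scale per output channel (assumption)

# dq(weight) with zero point 0: scale each int8 row back to float.
dq_weight = quant_weight.float() * scale.unsqueeze(1)
out_linear = torch.nn.functional.linear(activation, dq_weight)

# What the pass emits instead: a single fused op over the packed int8 weight.
out_fused = torch._weight_int8pack_mm(activation, quant_weight, scale)
torch.testing.assert_close(out_linear, out_fused)
```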

File tree: 2 files changed (+88, -0 lines)

backends/transforms/TARGETS

Lines changed: 15 additions & 0 deletions
```diff
@@ -73,6 +73,21 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "fuse_dequant_linear",
+    srcs = ["fuse_dequant_linear.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        ":utils",
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir:sym_util",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "fuse_view_copy",
     srcs = ["fuse_view_copy.py"],
```
backends/transforms/fuse_dequant_linear.py (new file)

Lines changed: 73 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseDequantLinearPass(ExportPass):
    """
    Fuses weight dequantize_per_channel nodes with linear nodes into
    weight_int8pack_mm nodes, for 8-bit weight-only quantization.

    Replaces dq(weight) -> linear(activation, dq) with weight_int8pack_mm
    Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add
    """

    def fuse_dequant_with_linear(
        self,
        graph_module: torch.fx.GraphModule,
        dequant_node: torch.fx.Node,
        linear_node: torch.fx.Node,
    ):
        activations = linear_node.args[0]
        bias = None
        if len(linear_node.args) > 2:
            bias = linear_node.args[2]
        quant_weight = dequant_node.args[0]
        scale = dequant_node.args[1]

        with graph_module.graph.inserting_before(linear_node):
            weight_int8pack_mm_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten._weight_int8pack_mm.default,
                (activations, quant_weight, scale),
            )
            if bias:
                add_node = graph_module.graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.add.Tensor,
                    (weight_int8pack_mm_node, bias),
                )
                linear_node.replace_all_uses_with(add_node)
            else:
                linear_node.replace_all_uses_with(weight_int8pack_mm_node)
            graph_module.graph.erase_node(linear_node)
            graph_module.graph.erase_node(dequant_node)

    def is_node_target(self, node, target):
        return node.op == "call_function" and node.target == target

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if self.is_node_target(node, exir_ops.edge.aten.linear.default):
                weight_node = node.args[1]
                if self.is_node_target(
                    weight_node,
                    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
                ):
                    # only fuse if weight tensor is int8 packed
                    quant_weight = weight_node.args[0]
                    if quant_weight.meta["val"].dtype != torch.int8:
                        continue
                    self.fuse_dequant_with_linear(graph_module, weight_node, node)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
```
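For orientation, here is a minimal sketch of how such a pass is wired into the ExecuTorch lowering flow. The `to_edge(...).transform([...])` plumbing is standard `executorch.exir` API; the toy module is an illustrative stand-in, and the pass only fires if an 8-bit weight-only quantization flow has already put a `dequantize_per_channel -> linear` pattern into the graph (on a plain float export it is a no-op):

```python
import torch
from executorch.exir import to_edge
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 16)

    def forward(self, x):
        return self.linear(x)


exported = torch.export.export(TinyLinear().eval(), (torch.randn(4, 8),))
edge = to_edge(exported)
# Apply the fusion pass; on a quantized graph this rewrites
# dq(weight) -> linear into _weight_int8pack_mm (+ add when bias is present).
edge = edge.transform([FuseDequantLinearPass()])
print(edge.exported_program().graph_module.graph)
```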
