
Commit 52a4cec

mcr229 authored and facebook-github-bot committed
linear fused relu in xnnpack delegate (#1551)
Summary:
Pull Request resolved: #1551

Add support for linear relu fusion in XNNPACK delegate

Reviewed By: balakv504, digantdesai

Differential Revision: D52574085

fbshipit-source-id: 88222fc8fdd7525ccb1f56afd45433ec31696669
1 parent 10ef41e commit 52a4cec

File tree

3 files changed: +100 −20 lines changed

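At a glance, this change lets the XNNPACK delegate lower a linear op whose output feeds a relu as a single fully connected node with an output clamp of [0, +inf), instead of emitting the relu separately. A minimal sketch of the eager-mode pattern this targets, mirroring the LinearReluModule added in the new tests (sizes and the instantiation below are illustrative):

import torch


class LinearReluModule(torch.nn.Module):
    """Linear followed by ReLU -- the pattern the delegate can now fuse."""

    def __init__(self, in_size: int, out_size: int, use_bias: bool):
        super().__init__()
        self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # When lowered to XNNPACK, the relu becomes an output_min_max clamp of
        # [0, +inf) on the fused fully connected node rather than a separate op.
        return torch.nn.functional.relu(self.linear(x))


module = LinearReluModule(in_size=4, out_size=8, use_bias=True)
sample_inputs = (torch.randn(2, 4),)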

backends/xnnpack/operators/op_linear.py

Lines changed: 55 additions & 19 deletions
@@ -9,17 +9,20 @@
 import torch
 from executorch.backends.xnnpack.operators.node_visitor import (
     get_input_node,
-    InputTypeToIndex,
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    OutputMinMax,
     XNNFullyConnected,
     XNNGraph,
     XNode,
 )
+from executorch.backends.xnnpack.utils.utils import get_relu_fused_node
 
 from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
+from executorch.exir.dialects._ops import ops as exir_ops
 
 
 @register_node_visitor
@@ -36,30 +39,62 @@ def define_node(
         vals_to_ids: Dict[torch.fx.Node, int],
         debug_handle: int,
     ) -> None:
-        input_type_map = (
-            InputTypeToIndex(node_input=0, node_weight=1, node_bias=2)
-            if len(node.args) == 3
-            else InputTypeToIndex(node_input=0, node_weight=1)
-        )
-        self.define_nodes_tensor_inputs_outputs(
-            node, xnn_graph, vals_to_ids, input_type_map=input_type_map
-        )
-
-        # bias
-        bias_id = (
-            XNN_INVALID_VALUE_ID
-            if len(node.args) == 2
-            else vals_to_ids[get_input_node(node, input_type_map.node_bias)]
-        )
 
         # input
-        input_id = vals_to_ids[get_input_node(node, input_type_map.node_input)]
+        input_node = get_input_node(node, 0)
+        input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
+        self.define_tensor(
+            input_node,
+            xnn_graph,
+            vals_to_ids,
+            quant_params=input_quant_params,
+        )
+        input_id = vals_to_ids[input_node]
 
         # filter
-        filter_id = vals_to_ids[get_input_node(node, input_type_map.node_weight)]
+        weight_node = get_input_node(node, 1)
+        weight_quant_params = QuantParams.from_weights(
+            weight_node, self._exported_program
+        )
+        self.define_tensor(
+            weight_node,
+            xnn_graph,
+            vals_to_ids,
+            quant_params=weight_quant_params,
+        )
+        filter_id = vals_to_ids[weight_node]
+
+        # bias
+        if len(node.args) > 2:
+            bias_node = get_input_node(node, 2)
+            bias_quant_params = QuantParams.from_bias(
+                bias_node, weight_quant_params, input_quant_params
+            )
+            self.define_tensor(
+                get_input_node(node, 2),
+                xnn_graph,
+                vals_to_ids,
+                quant_params=bias_quant_params,
+            )
+            bias_id = vals_to_ids[bias_node]
+        else:
+            bias_id = XNN_INVALID_VALUE_ID
 
         # output
-        output_id = vals_to_ids[node]
+        output_node = get_relu_fused_node(node) or node
+        output_min_max = None
+        if output_node.target == exir_ops.edge.aten.relu.default:
+            output_node.meta["XNNPACK_FUSED"] = True
+            output_min_max = OutputMinMax(output_min=0, output_max="+inf")
+
+        output_quant_params = QuantParams.from_outputs(output_node)
+        self.define_tensor(
+            output_node,
+            xnn_graph,
+            vals_to_ids,
+            quant_params=output_quant_params,
+        )
+        output_id = vals_to_ids[output_node]
 
         ser_node = XNode(
             xnode_union=XNNFullyConnected(
@@ -70,5 +105,6 @@ def define_node(
                 flags=0,
             ),
             debug_handle=debug_handle,
+            output_min_max=output_min_max,
         )
         xnn_graph.xnodes.append(ser_node)
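The key change in define_node is that the output tensor is defined against get_relu_fused_node(node) or node: when the linear feeds a relu that can be fused, the relu node becomes the output, is marked XNNPACK_FUSED, and the serialized XNNFullyConnected carries output_min_max = [0, +inf). The helper's body is not part of this diff; below is a minimal sketch of what it might do, assuming fusion is only considered when the relu is the linear's sole consumer (the name and checks are illustrative, not the actual utility in backends/xnnpack/utils/utils.py):

from typing import Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops


def get_relu_fused_node_sketch(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    # Only fuse when the relu is the single consumer; otherwise other users
    # would still need the un-clamped linear output.
    if len(node.users) != 1:
        return None
    (user,) = node.users
    if user.op == "call_function" and user.target == exir_ops.edge.aten.relu.default:
        return user
    return None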

backends/xnnpack/passes/tag_implicit_q_dq_pass.py

Lines changed: 5 additions & 0 deletions
@@ -64,6 +64,11 @@ class TagImplicitQDqPass(XNNPACKPass):
                 _END_OF_CHAIN_MARKER: True,
             }
         },
+        exir_ops.edge.aten.linear.default.name(): {
+            exir_ops.edge.aten.relu.default.name(): {
+                _END_OF_CHAIN_MARKER: True,
+            }
+        },
     }
     IS_IMPLICIT_Q_DQ_TAG = "IS_IMPLICIT_Q_DQ_TAG"
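The new dictionary entry registers linear -> relu as a chain whose surrounding quantize/dequantize nodes the pass may tag as implicit, mirroring the existing entry whose closing braces appear in the context above. How the pass consumes this map is not shown in the diff; the following is a rough, standalone sketch of walking such a nested chain map (marker name, map keys, and function are illustrative, not the pass's actual code):

from typing import Dict, List, Union

_END_OF_CHAIN_MARKER = "END_OF_CHAIN"  # stand-in for the pass's private marker

# Nested map: the outer key starts a chain, inner keys continue it, and the
# marker flags a point where the chain may legally end.
CHAIN_MAP: Dict[str, dict] = {
    "aten.linear.default": {
        "aten.relu.default": {_END_OF_CHAIN_MARKER: True},
    },
}


def ends_supported_chain(op_names: List[str], chain_map: Dict[str, dict]) -> bool:
    """Return True if op_names follows the map and stops at an end-of-chain marker."""
    level: Union[dict, bool] = chain_map
    for name in op_names:
        if not isinstance(level, dict) or name not in level:
            return False
        level = level[name]
    return isinstance(level, dict) and level.get(_END_OF_CHAIN_MARKER, False)


assert ends_supported_chain(["aten.linear.default", "aten.relu.default"], CHAIN_MAP)
assert not ends_supported_chain(["aten.linear.default"], CHAIN_MAP)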

backends/xnnpack/test/ops/linear.py

Lines changed: 40 additions & 1 deletion
@@ -204,6 +204,45 @@ def forward(self, x, y):
             LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
         )
 
+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            self._test_linear(
+                lambda in_size, out_size: LinearReluModule(
+                    in_size,
+                    out_size,
+                    use_bias,  # noqa
+                ),
+                uses_bias=use_bias,
+            )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            self._test_linear(
+                lambda in_size, out_size: LinearReluModule(
+                    in_size,
+                    out_size,
+                    use_bias,  # noqa
+                ),
+                uses_bias=use_bias,
+                quant=True,
+            )
+
     def _test_linear(self, make_module, uses_bias, quant=False):
         aten_op, edge_op = (
             (
@@ -256,7 +295,7 @@ def _test_linear(self, make_module, uses_bias, quant=False):
         tester.to_executorch()
         tester.serialize()
         tester.run_method()
-        tester.compare_outputs()
+        tester.compare_outputs(qtol=quant)
 
     def _test_dqlinear(
         self, module, inputs, linear_count=1, is_per_channel=False, uses_bias=False
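The only change to the shared _test_linear flow is compare_outputs(qtol=quant): for the quantized (qs8) variants quant is True, which presumably relaxes the output comparison so small quantization error between the fused and unfused relu paths does not fail the test. A rough sketch of that kind of tolerance check (the Tester's real comparison may differ; this is only illustrative):

import torch


def outputs_close(reference: torch.Tensor, actual: torch.Tensor, qtol: int = 0) -> bool:
    # With qtol=0 this degenerates to a plain allclose; a nonzero qtol lets the
    # results differ by up to qtol, absorbing off-by-one rounding in quantized runs.
    if qtol:
        return bool(torch.all((reference - actual).abs() <= qtol))
    return bool(torch.allclose(reference, actual))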
