Skip to content

Commit c33932c

Browse files
digantdesaifacebook-github-bot
authored andcommitted
Support unlift=false for per_channel qparams (#868)
Summary: Pull Request resolved: #868 This was a TODO, is now needed to support Saliency model with `unlift=False` Reviewed By: kimishpatel Differential Revision: D50197184 fbshipit-source-id: c172fd77653fd458b9eb6bb56248416cddf8483d
1 parent d6bd012 commit c33932c

File tree

5 files changed

+126
-10
lines changed

5 files changed

+126
-10
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,9 @@ def define_nodes_tensor_inputs_outputs(
458458
)
459459
# Define Weight Node
460460
weight_node = get_input_node(node, input_type_map.node_weight)
461-
weight_quant_params = QuantParams.from_weights(weight_node)
461+
weight_quant_params = QuantParams.from_weights(
462+
weight_node, self._exported_program
463+
)
462464
self.define_tensor(
463465
weight_node,
464466
xnn_graph,

backends/xnnpack/operators/op_conv2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def define_node(
7373
# shape for xnnpack convolution is (oc, height, width, inc/groups), to convert
7474
# to the proper shape, this is essentially a NCHW to NHWC conversion
7575
weight_node = get_input_node(node, 1)
76-
weight_quant_params = QuantParams.from_weights(weight_node)
76+
weight_quant_params = QuantParams.from_weights(
77+
weight_node, self._exported_program
78+
)
7779
self.define_tensor(
7880
weight_node,
7981
xnn_graph,

backends/xnnpack/operators/quant_params.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
import torch
1212
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
1313
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
14-
from executorch.backends.xnnpack.utils.utils import check_or_raise, is_param_node
14+
from executorch.backends.xnnpack.utils.utils import (
15+
check_or_raise,
16+
get_param_tensor,
17+
is_param_node,
18+
)
1519
from executorch.exir.dialects._ops import ops as exir_ops
1620
from torch.export import ExportedProgram
1721

@@ -97,7 +101,9 @@ def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
97101
)
98102

99103
@classmethod
100-
def from_q_dq_node(cls, quant_node: torch.fx.Node) -> QuantParams:
104+
def from_q_dq_node(
105+
cls, quant_node: torch.fx.Node, ep: Optional[ExportedProgram] = None
106+
) -> QuantParams:
101107
check_or_raise(
102108
is_quant(quant_node) or is_dequant(quant_node),
103109
f"building quantizer from q/dq node but was given node:{quant_node}",
@@ -119,11 +125,18 @@ def from_q_dq_node(cls, quant_node: torch.fx.Node) -> QuantParams:
119125
if per_channel:
120126
assert isinstance(scale, torch.fx.Node) and isinstance(scale.target, str)
121127
assert isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
122-
# TODO: use get_param_tensor()
123-
scale = getattr(quant_node.graph.owning_module, scale.target)
124-
zp = getattr(quant_node.graph.owning_module, zp.target)
125-
axis = cast(int, quant_node.args[3])
128+
assert (
129+
ep is not None
130+
), "ExportedProgram must be provided to extract per channel params"
131+
132+
def _get_tensor(node):
133+
param = get_param_tensor(ep, node)
134+
assert param is not None, f"Expected to find param tensor for {node}"
135+
return cast(torch.Tensor, param)
126136

137+
scale = _get_tensor(scale)
138+
zp = _get_tensor(zp)
139+
axis = cast(int, quant_node.args[3])
127140
check_or_raise(
128141
bool(
129142
quant_node.args[-1] != torch.uint8
@@ -152,7 +165,9 @@ def from_q_dq_node(cls, quant_node: torch.fx.Node) -> QuantParams:
152165
)
153166

154167
@classmethod
155-
def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
168+
def from_weights(
169+
cls, tensor_node: torch.fx.Node, ep: Optional[ExportedProgram] = None
170+
) -> Optional[QuantParams]:
156171
# Ignore transpose for weights
157172
# TODO:T148540997 remove the t_copy/permute_copy check when convert addmm to linear
158173
dq = (
@@ -180,7 +195,7 @@ def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
180195
f"q->dq->permute_copy not derived from static weight, input to the q node: {q.all_input_nodes[0]}",
181196
)
182197

183-
return cls.from_q_dq_node(q)
198+
return cls.from_q_dq_node(q, ep)
184199

185200
@classmethod
186201
def from_inputs(

backends/xnnpack/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ python_unittest(
112112
srcs = [
113113
"ops/add.py",
114114
"ops/conv1d.py",
115+
"ops/conv2d.py",
115116
],
116117
deps = [
117118
"//caffe2:torch",

backends/xnnpack/test/ops/conv2d.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from typing import Optional
9+
10+
import torch
11+
12+
from executorch.backends.xnnpack.test.tester import Quantize, Tester
13+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
14+
get_symmetric_quantization_config,
15+
)
16+
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
17+
18+
19+
class Conv2d(torch.nn.Module):
20+
def __init__(
21+
self,
22+
in_channels=2,
23+
out_channels=1,
24+
kernel_size=(3, 3),
25+
stride=(2, 2),
26+
padding=(1, 1),
27+
dilation=(1, 1),
28+
groups=1,
29+
bias=True,
30+
padding_mode="zeros",
31+
batches=1,
32+
width=8,
33+
height=8,
34+
):
35+
super().__init__()
36+
self.batches = batches
37+
self.width = width
38+
self.height = height
39+
self.in_channels = in_channels
40+
41+
self.conv = torch.nn.Conv2d(
42+
in_channels=in_channels,
43+
out_channels=out_channels,
44+
kernel_size=kernel_size,
45+
stride=stride,
46+
padding=padding,
47+
dilation=dilation,
48+
groups=groups,
49+
bias=bias,
50+
padding_mode=padding_mode,
51+
)
52+
53+
def forward(self, x):
54+
return self.conv(x)
55+
56+
def get_inputs(self):
57+
return (torch.randn(self.batches, self.in_channels, self.height, self.width),)
58+
59+
60+
class TestConv2d(unittest.TestCase):
61+
def _test(
62+
self, m: torch.nn.Module, quant_config: Optional[QuantizationConfig] = None
63+
):
64+
tester = Tester(m.eval(), m.get_inputs())
65+
66+
if quant_config is not None:
67+
tester = tester.quantize(Quantize(quantization_config=quant_config))
68+
tester.check(["torch.ops.quantized_decomposed"])
69+
70+
(
71+
tester.export()
72+
.check_count({"torch.ops.aten.convolution.default": 1})
73+
.to_edge()
74+
.check_count(
75+
{"executorch_exir_dialects_edge__ops_aten_convolution_default": 1}
76+
)
77+
.partition()
78+
.check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
79+
.check_count({"torch.ops.executorch_call_delegate": 1})
80+
.to_executorch()
81+
.serialize()
82+
.run_method()
83+
.compare_outputs()
84+
)
85+
86+
def test_conv2d(self) -> None:
87+
self._test(Conv2d())
88+
89+
def test_qconv2d(self) -> None:
90+
self._test(Conv2d(), quant_config=get_symmetric_quantization_config())
91+
92+
def test_qconv2d_per_channel(self) -> None:
93+
self._test(
94+
Conv2d(),
95+
quant_config=get_symmetric_quantization_config(is_per_channel=True),
96+
)

0 commit comments

Comments
 (0)