
Commit 7bfec8f

mcr229 authored and facebook-github-bot committed
move quant_params class out of quant_utils (#308)
Summary:
Pull Request resolved: #308

QuantParams always felt a little out of place in quant_utils. Additionally, the next diff introduces a dependency on the implicit q/dq pass, which would create a circular dependency with quant_utils. So we move QuantParams out of utils and into its own file next to node_visitor.

Reviewed By: digantdesai, salilsdesai

Differential Revision: D48805091

fbshipit-source-id: c9800171840b9e5679fdacdf7e2538796456e8df
1 parent 18ffbcc commit 7bfec8f
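
For downstream code the change is purely an import-path move; a minimal before/after sketch (both paths taken from the diffs below):

    # Before this commit:
    from executorch.backends.xnnpack.utils.quant_utils import QuantParams

    # After this commit:
    from executorch.backends.xnnpack.operators.quant_params import QuantParams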

File tree

11 files changed (+246, -237 lines)


backends/xnnpack/operators/node_visitor.py

Lines changed: 2 additions & 2 deletions
@@ -10,6 +10,8 @@
 import torch
 from executorch.backends.transforms import get_shape
 
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
+
 from executorch.backends.xnnpack.passes.channels_last_tagged_reshape_pass import (
     ChannelsLastTaggedReshapePass,
 )
@@ -25,8 +27,6 @@
     XNNTensorValue,
     XValue,
 )
-
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 from executorch.backends.xnnpack.utils.utils import (
     check_or_raise,
     get_input_node,

backends/xnnpack/operators/op_add.py

Lines changed: 1 addition & 1 deletion
@@ -11,13 +11,13 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     OutputMinMax,
     XNNAdd,
     XNNGraph,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 
 from executorch.backends.xnnpack.utils.utils import get_input_node, get_relu_fused_node
 from executorch.exir.dialects._ops import ops as exir_ops

backends/xnnpack/operators/op_cat.py

Lines changed: 1 addition & 1 deletion
@@ -11,14 +11,14 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNConcatenate2,
     XNNConcatenate3,
     XNNConcatenate4,
     XNNGraph,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 from executorch.backends.xnnpack.utils.utils import PERM_NHWC_TO_NCHW
 from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
 

backends/xnnpack/operators/op_conv2d.py

Lines changed: 1 addition & 1 deletion
@@ -11,14 +11,14 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     OutputMinMax,
     XNNConv2d,
     XNNDepthwiseConv2d,
     XNNGraph,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 from executorch.backends.xnnpack.utils.utils import (
     check_or_raise,
     get_input_node,

backends/xnnpack/operators/op_dequantize_per_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -11,13 +11,13 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNConvert,
     XNNGraph,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 from executorch.backends.xnnpack.utils.utils import get_input_node
 
 

backends/xnnpack/operators/op_multiply.py

Lines changed: 1 addition & 1 deletion
@@ -11,13 +11,13 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     OutputMinMax,
     XNNGraph,
     XNNMultiply,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 
 from executorch.backends.xnnpack.utils.utils import get_input_node, get_relu_fused_node
 from executorch.exir.dialects._ops import ops as exir_ops

backends/xnnpack/operators/op_quantize_per_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -11,13 +11,13 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNConvert,
     XNNGraph,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 from executorch.backends.xnnpack.utils.utils import get_input_node
 
 

backends/xnnpack/operators/op_sub.py

Lines changed: 1 addition & 1 deletion
@@ -11,13 +11,13 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     OutputMinMax,
     XNNGraph,
     XNNSubtract,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 
 from executorch.backends.xnnpack.utils.utils import get_input_node, get_relu_fused_node
 from executorch.exir.dialects._ops import ops as exir_ops

backends/xnnpack/operators/op_to_copy.py

Lines changed: 1 addition & 1 deletion
@@ -12,13 +12,13 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNGraph,
     XNNStaticTranspose,
     XNode,
 )
-from executorch.backends.xnnpack.utils.quant_utils import QuantParams
 from executorch.backends.xnnpack.utils.utils import (
     check_or_raise,
     get_input_node,
backends/xnnpack/operators/quant_params.py

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from typing import cast, Optional, Union
+
+import torch
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.backends.xnnpack.utils.utils import check_or_raise
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class QuantParams:
+    """
+    QuantParams class, to represent the parameters and metadata needed
+    to quantize a tensor. The metadata can technically all be encapsulated
+    within the quant torch.fx.Node; however, there are some cases in which
+    nodes which are meant to be quantized for XNNPACK are not quantized
+    in PyTorch IR, specifically bias nodes. In this case, we can still build
+    a quantizer class to serialize the quantized attributes needed for XNNPACK.
+
+    Attributes:
+        per_channel: Whether this quantization is per channel or per tensor
+        q_input: node that is the input to this quantization
+        scale: tensor or float that is used as the quantization scale
+        zp: tensor or float that is used as the quantization zero point
+        axis: used for per_channel quantization, representing the axis
+        dtype: dtype of the type being quantized to
+        qmin: quantization minimum
+        qmax: quantization maximum
+        is_output: whether this is an output node or not
+        is_input: whether this is an input node or not
+    """
+
+    def __init__(
+        self,
+        per_channel: bool,
+        q_input: torch.fx.Node,
+        scale: Union[torch.Tensor, float],
+        zp: Union[torch.Tensor, float],
+        axis: int,
+        dtype: torch.dtype,
+        qmax: int,
+        qmin: int,
+        is_output: bool,
+        is_input: bool,
+        is_dynamic: bool = False,
+    ) -> None:
+        self.per_channel = per_channel
+        self.q_input = q_input
+        self.scale = scale
+        self.zp = zp
+        self.axis = axis
+        self.dtype = dtype
+        self.qmax = qmax
+        self.qmin = qmin
+        self.is_output = is_output
+        self.is_input = is_input
+        self.is_dynamic = is_dynamic
+
+    def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.per_channel:
+            return exir_ops.edge.quantized_decomposed.quantize_per_channel.default(
+                tensor, self.scale, self.zp, self.axis, self.qmin, self.qmax, self.dtype
+            )
+        else:
+            return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
+                tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
+            )
+
+    @classmethod
+    def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
+        q_input = quant_node.args[0]  # fp32 input
+        assert isinstance(q_input, torch.fx.Node)
+        return cls(
+            per_channel=False,  # True is not valid
+            q_input=q_input,
+            scale=0.0,  # no need
+            zp=0.0,  # no need
+            axis=0,  # no need
+            dtype=torch.float32,  # will be quantized at runtime
+            qmax=0,  # no need
+            qmin=0,  # no need
+            is_output=False,
+            is_input=q_input.op == "placeholder",
+            is_dynamic=True,
+        )
+
+    @classmethod
+    def from_q_dq_node(cls, quant_node: torch.fx.Node) -> QuantParams:
+        check_or_raise(
+            is_quant(quant_node) or is_dequant(quant_node),
+            f"building quantizer from q/dq node but was given node:{quant_node}",
+        )
+        q_input = quant_node.all_input_nodes[0]
+
+        if quant_node.target in [
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+        ]:
+            return cls._from_dynamic_input_node(q_input)
+
+        per_channel = quant_node.target in [
+            exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        ]
+        scale = quant_node.args[1]
+        zp = quant_node.args[2]
+        axis = 0
+        if per_channel:
+            assert isinstance(scale, torch.fx.Node) and isinstance(scale.target, str)
+            assert isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
+            # TODO: use get_param_tensor()
+            scale = getattr(quant_node.graph.owning_module, scale.target)
+            zp = getattr(quant_node.graph.owning_module, zp.target)
+            axis = cast(int, quant_node.args[3])
+
+        check_or_raise(
+            bool(
+                quant_node.args[-1] != torch.uint8
+                or quant_node.args[-1] != torch.quint8
+            ),
+            "XNNPACK does not support unsigned quantization",
+        )
+        dtype = cast(torch.dtype, quant_node.args[-1])
+        qmax = cast(int, quant_node.args[-2])
+        qmin = cast(int, quant_node.args[-3])
+        is_output = any(
+            user_node.op == "output" for user_node in quant_node.users.keys()
+        )
+        is_input = q_input.op == "placeholder"
+        return cls(
+            per_channel,
+            q_input,
+            scale,
+            zp,
+            axis,
+            dtype,
+            qmax,
+            qmin,
+            is_output,
+            is_input,
+        )
+
+    @classmethod
+    def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
+        # Ignore transpose for weights
+        # TODO:T148540997 remove the t_copy/permute_copy check when converting addmm to linear
+        dq = (
+            tensor_node.all_input_nodes[0]
+            if tensor_node.target
+            in (
+                exir_ops.edge.aten.permute_copy.default,
+                exir_ops.edge.aten.t_copy.default,
+            )
+            else tensor_node
+        )
+        # check that the input of t_copy/permute_copy is dequant
+        if not is_dequant(dq):
+            return None
+
+        # input of dq is q
+        check_or_raise(is_quant(dq.all_input_nodes[0]), "expected input to dq to be q")
+
+        q = dq.all_input_nodes[0]
+
+        # replace this with pointing to the actual weight value.
+        # if no one else uses this weight value then take it out of the toplevel module
+        check_or_raise(
+            q.all_input_nodes[0].op in ["get_attr", "placeholder"],
+            f"q->dq->permute_copy not derived from static weight, input to the q node: {q.all_input_nodes[0]}",
+        )
+
+        return cls.from_q_dq_node(q)
+
+    @classmethod
+    def from_inputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
+        # tensor_node is quantized if it is produced by a dequant node
+        if is_dequant(tensor_node):
+            return cls.from_q_dq_node(tensor_node)
+
+        return None
+
+    @classmethod
+    def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
+        # tensor_node can also be quantized if it is used in a q -> dq chain
+        if len(tensor_node.users) == 1:
+            q = list(tensor_node.users.keys())[0]
+            # Check if the user is a q node
+            if is_quant(q):
+                return cls.from_q_dq_node(q)
+
+        return None
+
+    @classmethod
+    def from_bias(
+        cls,
+        bias: torch.fx.Node,
+        weight_quantizer: Optional[QuantParams],
+        input_quantizer: Optional[QuantParams],
+    ) -> Optional[QuantParams]:
+        if weight_quantizer is None or input_quantizer is None:
+            check_or_raise(
+                weight_quantizer is None and input_quantizer is None,
+                "Weight and Input should both be quantized",
+            )
+            return None
+
+        if input_quantizer.is_dynamic:
+            # No need to quantize bias for dynamic quantization
+            return None
+
+        check_or_raise(
+            not input_quantizer.per_channel,
+            "Input can not be quantized per channel",
+        )
+
+        check_or_raise(
+            isinstance(input_quantizer.scale, float),
+            f"q_input scale should be float, but got {input_quantizer.scale}",
+        )
+        return cls(
+            per_channel=weight_quantizer.per_channel,
+            q_input=bias,
+            scale=weight_quantizer.scale * cast(float, input_quantizer.scale),
+            zp=weight_quantizer.zp * 0,
+            axis=weight_quantizer.axis,
+            dtype=torch.int32,
+            qmin=-(2**31),
+            qmax=(2**31) - 1,
+            is_output=False,
+            is_input=False,
+        )
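
To make the new module concrete, here is a small, hypothetical usage sketch: the QuantParams classmethods are taken from the file above, but the helper function, its name, and the convolution node layout are illustrative assumptions, not part of this commit.

    # Hypothetical sketch, not from this commit: `conv_node` is assumed to be a
    # torch.fx.Node whose inputs are (activation, weight, bias).
    from executorch.backends.xnnpack.operators.quant_params import QuantParams

    def collect_quant_params(conv_node):
        act, weight, bias = conv_node.all_input_nodes[:3]
        # Set when the activation node is itself a dequantize node.
        input_qp = QuantParams.from_inputs(act)
        # Walks the q -> dq (-> permute/t_copy) chain feeding the weight.
        weight_qp = QuantParams.from_weights(weight)
        # Set when the node's sole user is a quantize node.
        output_qp = QuantParams.from_outputs(conv_node)
        # Bias params are derived rather than read from IR: scale = input_scale *
        # weight_scale, zero point 0, dtype int32, qmin/qmax = -(2**31) / (2**31 - 1).
        bias_qp = QuantParams.from_bias(bias, weight_qp, input_qp)
        return input_qp, weight_qp, output_qp, bias_qp

Routing these helpers through the operators package rather than quant_utils is what avoids the circular dependency called out in the summary.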
