Commit a7632f3
Add functions for usage with DQ/Q folding pass
Reuse the logic from the node-visiting quantization handling, but fetch the quantization parameters from the node meta values instead.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: I9a7bbf6384284e60118756ec5661f6b11847aba7
1 parent b981b2e commit a7632f3
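For context, both new helpers read quantization parameters that the folding pass is assumed to have stored in `node.meta` under the `input_qparams` and `output_qparams` keys, keyed by argument index. This layout is inferred from the `cast(dict[int, QuantArgs], ...)` calls in the diff below, not spelled out by the commit; a minimal sketch of it:

```python
from typing import NamedTuple
import torch

# Stand-in mirroring the QuantArgs NamedTuple defined in
# tosa_quant_utils.py, so this sketch runs on its own.
class QuantArgs(NamedTuple):
    scale: float
    zp: int
    qmin: int
    qmax: int
    dtype: torch.dtype

# Hypothetical meta dict of a folded two-input op (e.g. an int8 add):
# one QuantArgs per input index, one for the single output.
meta = {
    "input_qparams": {
        0: QuantArgs(scale=0.02, zp=0, qmin=-128, qmax=127, dtype=torch.int8),
        1: QuantArgs(scale=0.04, zp=0, qmin=-128, qmax=127, dtype=torch.int8),
    },
    "output_qparams": {
        0: QuantArgs(scale=0.05, zp=0, qmin=-128, qmax=127, dtype=torch.int8),
    },
}
```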

File tree

1 file changed (+97, -9 lines)

backends/arm/tosa_quant_utils.py

Lines changed: 97 additions & 9 deletions
```diff
@@ -42,6 +42,81 @@ def register_passable_op(op):
     passable_ops.append(op)
 
 
+def insert_rescale_ops_to_int32(
+    tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], node: Node
+) -> tuple[list[TosaSerializerTensor], float]:
+    """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
+    The scales are adjusted using the smallest scale of all 'nodes'.
+
+    Returns a list of the rescaled nodes and the scale factor used,
+    needed by insert_rescale_node_back_to_int8.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict, as opposed to 'rescale_nodes_to_int32' which searches
+    the graph upstream for DQ nodes.
+    """
+
+    tensors = inputs.copy()
+
+    # Reshape tensors according to the TOSA dim order
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values())
+
+    # Scale the int8 quantized inputs to a common scale in the integer
+    # domain
+    min_scale = min([qarg.scale for qarg in qargs])
+    scales = [qarg.scale / min_scale for qarg in qargs]
+
+    rescaled_nodes: list[TosaSerializerTensor] = []
+    for tensor, qarg, scale in zip(tensors, qargs, scales):
+        rescaled_nodes.append(
+            build_rescale_to_int32(
+                tosa_graph,
+                tensor,
+                qarg.zp,
+                scale,
+            )
+        )
+    return rescaled_nodes, min_scale
+
+
+def insert_rescale_node_back_to_int8(
+    tosa_graph: ts.TosaSerializer,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+) -> None:
+    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
+
+    Parameters:
+        node: the original node that is being handled by the rescales.
+        last_tensor: the TOSA tensor to rescale back.
+        scale: the scaling factor used to rescale to int32, as returned by
+            'insert_rescale_ops_to_int32'.
+        tosa_graph: the tosa_graph to manipulate.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict, as opposed to 'rescale_node_back_to_int8' which searches
+    the graph downstream for Q nodes.
+    """
+    assert len(node.meta["output_qparams"]) == 1
+
+    qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0]
+    output_rescale_scale = scale / qargs_out.scale
+
+    # Rescale back to INT8
+    build_rescale_from_int32(
+        tosa_graph,
+        last_tensor.name,
+        node.name,
+        qargs_out.zp,
+        output_rescale_scale,
+    )
+
+
 class QuantArgs(NamedTuple):
     scale: float
     zp: int
```
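To see how the pair is meant to be used, here is a hypothetical visitor for an int8 add. Only the two helper calls come from this commit; the `addIntermediate`/`addOperator` plumbing, the function name, and the argument shapes are assumptions sketched around them:

```python
import serializer.tosa_serializer as ts  # assumed import, as used by this file

# Hypothetical sketch of an int8 binary-op visitor built around the two
# new helpers; not code from this commit.
def visit_int8_add(tosa_graph, node, inputs, output):
    # 1. Bring both int8 inputs to a shared scale in int32; 'scale' is
    #    the smallest input scale, returned for the final rescale.
    rescaled, scale = insert_rescale_ops_to_int32(tosa_graph, inputs, node)

    # 2. Do the arithmetic in int32, where both operands share a scale.
    add_out = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
    tosa_graph.addOperator(
        ts.TosaOp.Op().ADD, [t.name for t in rescaled], [add_out.name]
    )

    # 3. Emit the final RESCALE back to int8, using the output_qparams
    #    stored in node.meta by the folding pass.
    insert_rescale_node_back_to_int8(tosa_graph, add_out, scale, node)
```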
```diff
@@ -61,6 +136,20 @@ def quantize_value(self, x):
     def dequantize_value(self, qx: int) -> float:
         return (qx - self.zp) * self.scale
 
+    @classmethod
+    def from_operator(cls, op, args):
+        if op in dq_q_ops:
+            return cls(
+                scale=cast(float, args[1]),
+                zp=cast(int, args[2]),
+                qmin=cast(int, args[3]),
+                qmax=cast(int, args[4]),
+                dtype=cast(torch.dtype, args[5]),
+            )
+        else:
+            # We're only handling per tensor quantization
+            raise NotImplementedError
+
 
 def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
     return np.clip(
```
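`from_operator` mirrors the argument layout of the per-tensor Q/DQ operators, `(input, scale, zero_point, qmin, qmax, dtype)`, which is why it indexes `args[1]` through `args[5]`. A hypothetical call (the `q_op` handle stands for any member of `dq_q_ops`; `args[0]` is never read, so a placeholder is fine):

```python
import torch

# Hypothetical args tuple as it would appear on a quantize_per_tensor
# node: (input, scale, zero_point, qmin, qmax, dtype).
args = (None, 0.05, 0, -128, 127, torch.int8)
qargs = QuantArgs.from_operator(q_op, args)  # q_op: assumed member of dq_q_ops
assert qargs == QuantArgs(0.05, 0, -128, 127, torch.int8)
```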
```diff
@@ -77,13 +166,7 @@ def dequantize_value(qx, qargs: QuantArgs):
 def qargs_from_qnode(node: torch.fx.Node):
     assert node.target in dq_q_ops, f"Op {node} is not a quant node."
 
-    return QuantArgs(
-        scale=cast(float, node.args[1]),
-        zp=cast(int, node.args[2]),
-        qmin=cast(int, node.args[3]),
-        qmax=cast(int, node.args[4]),
-        dtype=cast(torch.dtype, node.args[5]),
-    )
+    return QuantArgs.from_operator(node.target, node.args)
 
 
 def get_neighbour_quant_args(
```
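This refactor keeps the args-to-QuantArgs mapping in one place: `qargs_from_qnode` remains as the graph-walking entry point with its assertion, while the folding pass can presumably call `QuantArgs.from_operator` directly on an op/args pair before the Q/DQ nodes are erased.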
```diff
@@ -214,8 +297,13 @@ def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
 
 
 def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
-    if isinstance(node.target, Callable) and "tosa" in node.target.__name__:
-        return node.meta["val"].dtype
+    if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys():
+        # Check if the node has had its quantization parameters folded
+        # and retrieve the dtype from the meta dict in that case.
+        assert len(node.meta["output_qparams"]) == 1
+        qargs = cast(QuantArgs, node.meta["output_qparams"][0])
+        return qargs.dtype
+
     if node.target in dq_q_ops:
         return cast(torch.dtype, node.args[5])
```
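Putting the two rescale helpers from the first hunk together, here is a quick numeric check of the scale arithmetic they implement. Plain Python float multiplication stands in for the fixed-point multiplier/shift a real TOSA RESCALE uses, and all scales are made up:

```python
s0, s1, s_out = 0.02, 0.04, 0.05   # hypothetical per-tensor scales, all zero-points 0
q0 = round(0.50 / s0)              # 25, the int8 encoding of 0.50
q1 = round(0.48 / s1)              # 12, the int8 encoding of 0.48

# insert_rescale_ops_to_int32: the common scale is min(s0, s1), and each
# input is multiplied by s_i / min_scale on its way into int32.
min_scale = min(s0, s1)                   # 0.02
i0 = round(q0 * (s0 / min_scale))         # 25
i1 = round(q1 * (s1 / min_scale))         # 24
acc = i0 + i1                             # 49, represents 49 * 0.02 = 0.98

# insert_rescale_node_back_to_int8: output_rescale_scale = scale / s_out.
q_out = round(acc * (min_scale / s_out))  # 20
print(q_out * s_out)                      # 1.0, vs. the true sum 0.98
```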
