@@ -42,6 +42,81 @@ def register_passable_op(op):
     passable_ops.append(op)
 
 
+def insert_rescale_ops_to_int32(
+    tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], node: Node
+) -> tuple[list[TosaSerializerTensor], float]:
+    """Rescales all 'inputs' to int32, adding suitable RESCALE ops to 'tosa_graph'.
+    The scales are adjusted using the smallest scale of all 'inputs'.
+
+    Returns a list of the rescaled tensors and the scale factor used,
+    needed by insert_rescale_node_back_to_int8.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict, as opposed to 'rescale_nodes_to_int32', which
+    searches the graph upstream for DQ nodes.
+    """
+
+    tensors = inputs.copy()
+
+    # Reshape tensors according to the TOSA dim order
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values())
+
+    # Scale the int8 quantized inputs to a common scale in the integer domain.
+    min_scale = min([qarg.scale for qarg in qargs])
+    scales = [qarg.scale / min_scale for qarg in qargs]
+
+    rescaled_nodes: list[TosaSerializerTensor] = []
+    for tensor, qarg, scale in zip(tensors, qargs, scales):
+        rescaled_nodes.append(
+            build_rescale_to_int32(
+                tosa_graph,
+                tensor,
+                qarg.zp,
+                scale,
+            )
+        )
+    return rescaled_nodes, min_scale
+
+
+def insert_rescale_node_back_to_int8(
+    tosa_graph: ts.TosaSerializer,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+) -> None:
+    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
+    Parameters:
+        node: the original node that is being handled by the rescales.
+        last_tensor: the TOSA tensor to rescale back.
+        scale: the scaling factor used to rescale to int32, returned by 'insert_rescale_ops_to_int32'.
+        tosa_graph: the tosa_graph to manipulate.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict, as opposed to 'rescale_node_back_to_int8', which
+    searches the graph downstream for Q nodes.
+    """
+    assert len(node.meta["output_qparams"]) == 1
+
+    qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0]
+    output_rescale_scale = scale / qargs_out.scale
+
+    # Rescale back to int8.
+    build_rescale_from_int32(
+        tosa_graph,
+        last_tensor.name,
+        node.name,
+        qargs_out.zp,
+        output_rescale_scale,
+    )
+
+
 class QuantArgs(NamedTuple):
     scale: float
     zp: int
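For intuition, here is a minimal runnable sketch of the scale arithmetic the two new helpers implement, using made-up quantization parameters and plain NumPy in place of TOSA RESCALE ops (the real ops use integer multiplier/shift pairs rather than float math):

```python
import numpy as np

# Hypothetical per-tensor q-params for the two int8 inputs of an add.
scales = [0.02, 0.05]
zps = [3, -8]

# insert_rescale_ops_to_int32: bring all inputs to the smallest
# input scale (min_scale) in the int32 domain.
min_scale = min(scales)                   # 0.02
ratios = [s / min_scale for s in scales]  # [1.0, 2.5]

xs = [np.array([10, -5], dtype=np.int8) for _ in scales]
xs_int32 = [
    ((x.astype(np.int32) - zp) * r).astype(np.int32)
    for x, zp, r in zip(xs, zps, ratios)
]

acc = xs_int32[0] + xs_int32[1]  # the integer target op itself

# insert_rescale_node_back_to_int8: output_rescale_scale = scale / qargs_out.scale
out_scale, out_zp = 0.04, 0
y = np.clip(np.round(acc * (min_scale / out_scale)) + out_zp, -128, 127).astype(np.int8)
```

In the real helpers these parameters are read from `node.meta["input_qparams"]` and `node.meta["output_qparams"]`, which the folding pass populates.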
@@ -61,6 +136,20 @@ def quantize_value(self, x):
     def dequantize_value(self, qx: int) -> float:
         return (qx - self.zp) * self.scale
 
+    @classmethod
+    def from_operator(cls, op, args):
+        if op in dq_q_ops:
+            return cls(
+                scale=cast(float, args[1]),
+                zp=cast(int, args[2]),
+                qmin=cast(int, args[3]),
+                qmax=cast(int, args[4]),
+                dtype=cast(torch.dtype, args[5]),
+            )
+        else:
+            # We're only handling per-tensor quantization
+            raise NotImplementedError
+
 
 def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
     return np.clip(
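As a hedged usage sketch of the new classmethod: `q_op` below is assumed to be the per-tensor quantize op this file keeps in `dq_q_ops`, and the positional args follow the `(input, scale, zero_point, qmin, qmax, dtype)` layout of a q/dq node:

```python
import torch
from executorch.exir.dialects._ops import ops as exir_ops

# Assumption: a member of dq_q_ops in this file.
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default

qargs = QuantArgs.from_operator(q_op, (None, 0.05, -8, -128, 127, torch.int8))
assert qargs.scale == 0.05 and qargs.zp == -8 and qargs.dtype == torch.int8
```

Any other op raises `NotImplementedError`, since only per-tensor quantization is handled.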
@@ -77,13 +166,7 @@ def dequantize_value(qx, qargs: QuantArgs):
 def qargs_from_qnode(node: torch.fx.Node):
     assert node.target in dq_q_ops, f"Op {node} is not a quant node."
 
-    return QuantArgs(
-        scale=cast(float, node.args[1]),
-        zp=cast(int, node.args[2]),
-        qmin=cast(int, node.args[3]),
-        qmax=cast(int, node.args[4]),
-        dtype=cast(torch.dtype, node.args[5]),
-    )
+    return QuantArgs.from_operator(node.target, node.args)
 
 
 def get_neighbour_quant_args(
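With this refactor, `qargs_from_qnode` becomes a thin validation wrapper over `QuantArgs.from_operator`. A usage sketch, where `gm` stands for any `torch.fx.GraphModule` containing q/dq nodes:

```python
# Read the q-params straight off the first q/dq node in the graph.
dq_node = next(n for n in gm.graph.nodes if n.target in dq_q_ops)
qargs = qargs_from_qnode(dq_node)  # delegates to QuantArgs.from_operator
print(qargs.scale, qargs.zp, qargs.dtype)
```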
@@ -214,8 +297,13 @@ def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
 
 
 def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
-    if isinstance(node.target, Callable) and "tosa" in node.target.__name__:
-        return node.meta["val"].dtype
+    if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys():
+        # Check if the node has had its quantization parameters folded,
+        # and retrieve the dtype from the meta dict in that case.
+        assert len(node.meta["output_qparams"]) == 1
+        qargs = cast(QuantArgs, node.meta["output_qparams"][0])
+        return qargs.dtype
+
     if node.target in dq_q_ops:
         return cast(torch.dtype, node.args[5])
 
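A sketch of the resulting lookup order, where `node` stands for a hypothetical `torch.fx.Node` whose q-params were folded by the pass:

```python
# 1. Folded q-params take precedence: dtype comes from the meta dict.
node.meta["output_qparams"] = {
    0: QuantArgs(scale=0.05, zp=-8, qmin=-128, qmax=127, dtype=torch.int8)
}
assert get_quantized_node_output_dtype(node) == torch.int8

# 2. A bare q/dq node still falls through to the args-based path:
#    get_quantized_node_output_dtype(dq_node) == dq_node.args[5]
```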