11
11
12
12
import torch
13
13
14
- from executorch .exir import EdgeProgramManager
14
+ from executorch .exir import EdgeProgramManager , ExportedProgram
15
15
from executorch .exir .dialects ._ops import ops as exir_ops
16
16
17
17
from executorch .exir .pass_base import ExportPass
@@ -39,11 +39,33 @@ def quantize_input(
39
39
if len (target_placeholder .users ) != 1 :
40
40
raise ValueError (f"Input { input_index } has more than one users" )
41
41
quantize = next (iter (target_placeholder .users ))
42
+ if quantize .target not in [
43
+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
44
+ torch .ops .quantized_decomposed .quantize_per_tensor .default ,
45
+ ]:
46
+ raise ValueError (
47
+ f"Input { input_index } is not used by a quantize op. It's used by { quantize .target } "
48
+ )
49
+
42
50
if (
43
51
quantize .target
44
- ! = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
52
+ = = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
45
53
):
46
- raise ValueError (f"Input { input_index } is not used by a quantize op" )
54
+ replacement_op_dequant = (
55
+ exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
56
+ )
57
+ replacement_op_quant = (
58
+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
59
+ )
60
+ elif quantize .target == torch .ops .quantized_decomposed .quantize_per_tensor .default :
61
+ replacement_op_dequant = (
62
+ torch .ops .quantized_decomposed .dequantize_per_tensor .default
63
+ )
64
+ replacement_op_quant = (
65
+ torch .ops .quantized_decomposed .quantize_per_tensor .default
66
+ )
67
+ else :
68
+ raise ValueError (f"Invalid quantize op: { quantize .target } " )
47
69
48
70
# If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op
49
71
need_requant = False
@@ -83,7 +105,7 @@ def quantize_input(
83
105
84
106
with exported_program .graph_module .graph .inserting_before (quantize ):
85
107
input_dequant = exported_program .graph_module .graph .call_function (
86
- exir_ops . edge . quantized_decomposed . dequantize_per_tensor . default ,
108
+ replacement_op_dequant ,
87
109
args = (
88
110
target_placeholder ,
89
111
* quant_args ,
@@ -106,10 +128,8 @@ def quantize_input(
106
128
logger .info (f"Modifying program to take quantized input at index { input_index } " )
107
129
logger .info (f"Quantization parameters: { quant_args } " )
108
130
109
- target_placeholder .meta ["val" ] = (
110
- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default (
111
- target_placeholder .meta ["val" ], * quant_args
112
- )
131
+ target_placeholder .meta ["val" ] = replacement_op_quant (
132
+ target_placeholder .meta ["val" ], * quant_args
113
133
)
114
134
quantize .replace_all_uses_with (quantize .args [0 ])
115
135
@@ -138,10 +158,10 @@ def quantize_output(exported_program, output_index):
138
158
)
139
159
140
160
target_output = output_list [output_index ]
141
- if (
142
- target_output . target
143
- != exir_ops . edge .quantized_decomposed .dequantize_per_tensor .default
144
- ) :
161
+ if target_output . target not in [
162
+ exir_ops . edge . quantized_decomposed . dequantize_per_tensor . default ,
163
+ torch . ops .quantized_decomposed .dequantize_per_tensor .default ,
164
+ ] :
145
165
raise ValueError ("Output {output_index} is not a dequantize op" )
146
166
147
167
dequant = target_output
@@ -185,6 +205,7 @@ def __init__(
185
205
edge_program_manager : EdgeProgramManager ,
186
206
quantized_inputs_idx : Union [Dict [int , Dict [str , Any ]], List [int ]],
187
207
method_name : Optional [str ] = None ,
208
+ exported_program : Optional [ExportedProgram ] = None ,
188
209
):
189
210
super ().__init__ ()
190
211
self .edge_program_manager = edge_program_manager
@@ -196,31 +217,49 @@ def __init__(
196
217
for idx in quantized_inputs_idx :
197
218
self .quantized_inputs_idx_dict [idx ] = None
198
219
self .param_prefix_name = method_name
220
+ self .exported_program = exported_program
221
+ self .quant_args = {}
199
222
200
- def call (self , graph_module : torch .fx .GraphModule ):
201
- for i , qparams in self .quantized_inputs_idx_dict .items ():
202
- quant_args = quantize_input (
203
- self .edge_program_manager .exported_program (), i , qparams
204
- )
205
-
223
+ def edge_manager_update_quant_config_method (self , idx , quant_args ):
224
+ if self .edge_program_manager is not None :
206
225
if not self .edge_program_manager ._config_methods :
207
226
self .edge_program_manager ._config_methods = {}
208
227
209
228
self .edge_program_manager ._config_methods [
210
- get_config_method_name (self .param_prefix_name , "input" , i , "scale" )
229
+ get_config_method_name (self .param_prefix_name , "input" , idx , "scale" )
211
230
] = quant_args [0 ]
212
- self .edge_program_manager ._config_methods [ # pyre-ignore
213
- get_config_method_name (self .param_prefix_name , "input" , i , "zp" )
231
+ self .edge_program_manager ._config_methods [
232
+ get_config_method_name (self .param_prefix_name , "input" , idx , "zp" )
214
233
] = quant_args [1 ]
215
234
self .edge_program_manager ._config_methods [
216
- get_config_method_name (self .param_prefix_name , "input" , i , "quant_min" )
235
+ get_config_method_name (
236
+ self .param_prefix_name , "input" , idx , "quant_min"
237
+ )
217
238
] = quant_args [2 ]
218
239
self .edge_program_manager ._config_methods [
219
- get_config_method_name (self .param_prefix_name , "input" , i , "quant_max" )
240
+ get_config_method_name (
241
+ self .param_prefix_name , "input" , idx , "quant_max"
242
+ )
220
243
] = quant_args [3 ]
221
244
self .edge_program_manager ._config_methods [
222
- get_config_method_name (self .param_prefix_name , "input" , i , "dtype" )
245
+ get_config_method_name (self .param_prefix_name , "input" , idx , "dtype" )
223
246
] = scalar_type_enum (quant_args [4 ])
247
+
248
+ def edge_manager_update_quant_config_methods_all (self ):
249
+ if self .edge_program_manager is not None :
250
+ for idx , val in self .quant_args .items ():
251
+ self .edge_manager_update_quant_config_method (idx , val )
252
+
253
+ def call (self , graph_module : torch .fx .GraphModule ):
254
+ for i , qparams in self .quantized_inputs_idx_dict .items ():
255
+ exported_program = (
256
+ self .edge_program_manager .exported_program ()
257
+ if self .edge_program_manager is not None
258
+ else self .exported_program
259
+ )
260
+ self .quant_args [i ] = quantize_input (exported_program , i , qparams )
261
+ self .edge_manager_update_quant_config_method (i , self .quant_args [i ])
262
+
224
263
return PassResult (graph_module , True )
225
264
226
265
@@ -230,35 +269,53 @@ def __init__(
230
269
edge_program_manager : EdgeProgramManager ,
231
270
quantized_outputs_idx_list : List [int ],
232
271
method_name : Optional [str ] = None ,
272
+ exported_program : Optional [ExportedProgram ] = None ,
233
273
):
234
274
super ().__init__ ()
235
275
self .edge_program_manager = edge_program_manager
236
276
self .quantized_outputs_idx_list = quantized_outputs_idx_list
237
277
self .param_prefix_name = method_name
278
+ self .exported_program = exported_program
279
+ self .dequant_args = {}
238
280
239
- def call (self , graph_module : torch .fx .GraphModule ):
240
- for i in self .quantized_outputs_idx_list :
241
- dequant_args = quantize_output (
242
- self .edge_program_manager .exported_program (), i
243
- ) # noqa F841
244
-
281
+ def edge_manager_update_quant_config_method (self , idx , dequant_args ):
282
+ if self .edge_program_manager is not None :
245
283
if not self .edge_program_manager ._config_methods :
246
284
self .edge_program_manager ._config_methods = {}
247
285
248
286
self .edge_program_manager ._config_methods [
249
- get_config_method_name (self .param_prefix_name , "output" , i , "scale" )
287
+ get_config_method_name (self .param_prefix_name , "output" , idx , "scale" )
250
288
] = dequant_args [0 ]
251
- self .edge_program_manager ._config_methods [ # pyre-ignore
252
- get_config_method_name (self .param_prefix_name , "output" , i , "zp" )
289
+ self .edge_program_manager ._config_methods [
290
+ get_config_method_name (self .param_prefix_name , "output" , idx , "zp" )
253
291
] = dequant_args [1 ]
254
292
self .edge_program_manager ._config_methods [
255
- get_config_method_name (self .param_prefix_name , "output" , i , "quant_min" )
293
+ get_config_method_name (
294
+ self .param_prefix_name , "output" , idx , "quant_min"
295
+ )
256
296
] = dequant_args [2 ]
257
297
self .edge_program_manager ._config_methods [
258
- get_config_method_name (self .param_prefix_name , "output" , i , "quant_max" )
298
+ get_config_method_name (
299
+ self .param_prefix_name , "output" , idx , "quant_max"
300
+ )
259
301
] = dequant_args [3 ]
260
302
self .edge_program_manager ._config_methods [
261
- get_config_method_name (self .param_prefix_name , "output" , i , "dtype" )
303
+ get_config_method_name (self .param_prefix_name , "output" , idx , "dtype" )
262
304
] = scalar_type_enum (dequant_args [4 ])
263
305
306
+ def edge_manager_update_quant_config_methods_all (self ):
307
+ if self .edge_program_manager is not None :
308
+ for idx , val in self .dequant_args .items ():
309
+ self .edge_manager_update_quant_config_method (idx , val )
310
+
311
+ def call (self , graph_module : torch .fx .GraphModule ):
312
+ for i in self .quantized_outputs_idx_list :
313
+ exported_program = (
314
+ self .edge_program_manager .exported_program ()
315
+ if self .edge_program_manager is not None
316
+ else self .exported_program
317
+ )
318
+ self .dequant_args [i ] = quantize_output (exported_program , i ) # noqa F841
319
+ self .edge_manager_update_quant_config_method (i , self .dequant_args [i ])
320
+
264
321
return PassResult (graph_module , True )
0 commit comments