Skip to content

Commit 6b48e89

Browse files
authored
ModAI changes to export xnnpack delegated non_lowered_server_model
Differential Revision: D70704201 Pull Request resolved: #10989
1 parent 9c1186f commit 6b48e89

File tree

1 file changed

+93
-36
lines changed

1 file changed

+93
-36
lines changed

exir/passes/quantize_io_pass.py

Lines changed: 93 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import torch
1313

14-
from executorch.exir import EdgeProgramManager
14+
from executorch.exir import EdgeProgramManager, ExportedProgram
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616

1717
from executorch.exir.pass_base import ExportPass
@@ -39,11 +39,33 @@ def quantize_input(
3939
if len(target_placeholder.users) != 1:
4040
raise ValueError(f"Input {input_index} has more than one users")
4141
quantize = next(iter(target_placeholder.users))
42+
if quantize.target not in [
43+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
44+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
45+
]:
46+
raise ValueError(
47+
f"Input {input_index} is not used by a quantize op. It's used by {quantize.target}"
48+
)
49+
4250
if (
4351
quantize.target
44-
!= exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
52+
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
4553
):
46-
raise ValueError(f"Input {input_index} is not used by a quantize op")
54+
replacement_op_dequant = (
55+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
56+
)
57+
replacement_op_quant = (
58+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
59+
)
60+
elif quantize.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
61+
replacement_op_dequant = (
62+
torch.ops.quantized_decomposed.dequantize_per_tensor.default
63+
)
64+
replacement_op_quant = (
65+
torch.ops.quantized_decomposed.quantize_per_tensor.default
66+
)
67+
else:
68+
raise ValueError(f"Invalid quantize op: {quantize.target}")
4769

4870
# If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op
4971
need_requant = False
@@ -83,7 +105,7 @@ def quantize_input(
83105

84106
with exported_program.graph_module.graph.inserting_before(quantize):
85107
input_dequant = exported_program.graph_module.graph.call_function(
86-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
108+
replacement_op_dequant,
87109
args=(
88110
target_placeholder,
89111
*quant_args,
@@ -106,10 +128,8 @@ def quantize_input(
106128
logger.info(f"Modifying program to take quantized input at index {input_index}")
107129
logger.info(f"Quantization parameters: {quant_args}")
108130

109-
target_placeholder.meta["val"] = (
110-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
111-
target_placeholder.meta["val"], *quant_args
112-
)
131+
target_placeholder.meta["val"] = replacement_op_quant(
132+
target_placeholder.meta["val"], *quant_args
113133
)
114134
quantize.replace_all_uses_with(quantize.args[0])
115135

@@ -138,10 +158,10 @@ def quantize_output(exported_program, output_index):
138158
)
139159

140160
target_output = output_list[output_index]
141-
if (
142-
target_output.target
143-
!= exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
144-
):
161+
if target_output.target not in [
162+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
163+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
164+
]:
145165
raise ValueError("Output {output_index} is not a dequantize op")
146166

147167
dequant = target_output
@@ -185,6 +205,7 @@ def __init__(
185205
edge_program_manager: EdgeProgramManager,
186206
quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],
187207
method_name: Optional[str] = None,
208+
exported_program: Optional[ExportedProgram] = None,
188209
):
189210
super().__init__()
190211
self.edge_program_manager = edge_program_manager
@@ -196,31 +217,49 @@ def __init__(
196217
for idx in quantized_inputs_idx:
197218
self.quantized_inputs_idx_dict[idx] = None
198219
self.param_prefix_name = method_name
220+
self.exported_program = exported_program
221+
self.quant_args = {}
199222

200-
def call(self, graph_module: torch.fx.GraphModule):
201-
for i, qparams in self.quantized_inputs_idx_dict.items():
202-
quant_args = quantize_input(
203-
self.edge_program_manager.exported_program(), i, qparams
204-
)
205-
223+
def edge_manager_update_quant_config_method(self, idx, quant_args):
224+
if self.edge_program_manager is not None:
206225
if not self.edge_program_manager._config_methods:
207226
self.edge_program_manager._config_methods = {}
208227

209228
self.edge_program_manager._config_methods[
210-
get_config_method_name(self.param_prefix_name, "input", i, "scale")
229+
get_config_method_name(self.param_prefix_name, "input", idx, "scale")
211230
] = quant_args[0]
212-
self.edge_program_manager._config_methods[ # pyre-ignore
213-
get_config_method_name(self.param_prefix_name, "input", i, "zp")
231+
self.edge_program_manager._config_methods[
232+
get_config_method_name(self.param_prefix_name, "input", idx, "zp")
214233
] = quant_args[1]
215234
self.edge_program_manager._config_methods[
216-
get_config_method_name(self.param_prefix_name, "input", i, "quant_min")
235+
get_config_method_name(
236+
self.param_prefix_name, "input", idx, "quant_min"
237+
)
217238
] = quant_args[2]
218239
self.edge_program_manager._config_methods[
219-
get_config_method_name(self.param_prefix_name, "input", i, "quant_max")
240+
get_config_method_name(
241+
self.param_prefix_name, "input", idx, "quant_max"
242+
)
220243
] = quant_args[3]
221244
self.edge_program_manager._config_methods[
222-
get_config_method_name(self.param_prefix_name, "input", i, "dtype")
245+
get_config_method_name(self.param_prefix_name, "input", idx, "dtype")
223246
] = scalar_type_enum(quant_args[4])
247+
248+
def edge_manager_update_quant_config_methods_all(self):
249+
if self.edge_program_manager is not None:
250+
for idx, val in self.quant_args.items():
251+
self.edge_manager_update_quant_config_method(idx, val)
252+
253+
def call(self, graph_module: torch.fx.GraphModule):
254+
for i, qparams in self.quantized_inputs_idx_dict.items():
255+
exported_program = (
256+
self.edge_program_manager.exported_program()
257+
if self.edge_program_manager is not None
258+
else self.exported_program
259+
)
260+
self.quant_args[i] = quantize_input(exported_program, i, qparams)
261+
self.edge_manager_update_quant_config_method(i, self.quant_args[i])
262+
224263
return PassResult(graph_module, True)
225264

226265

@@ -230,35 +269,53 @@ def __init__(
230269
edge_program_manager: EdgeProgramManager,
231270
quantized_outputs_idx_list: List[int],
232271
method_name: Optional[str] = None,
272+
exported_program: Optional[ExportedProgram] = None,
233273
):
234274
super().__init__()
235275
self.edge_program_manager = edge_program_manager
236276
self.quantized_outputs_idx_list = quantized_outputs_idx_list
237277
self.param_prefix_name = method_name
278+
self.exported_program = exported_program
279+
self.dequant_args = {}
238280

239-
def call(self, graph_module: torch.fx.GraphModule):
240-
for i in self.quantized_outputs_idx_list:
241-
dequant_args = quantize_output(
242-
self.edge_program_manager.exported_program(), i
243-
) # noqa F841
244-
281+
def edge_manager_update_quant_config_method(self, idx, dequant_args):
282+
if self.edge_program_manager is not None:
245283
if not self.edge_program_manager._config_methods:
246284
self.edge_program_manager._config_methods = {}
247285

248286
self.edge_program_manager._config_methods[
249-
get_config_method_name(self.param_prefix_name, "output", i, "scale")
287+
get_config_method_name(self.param_prefix_name, "output", idx, "scale")
250288
] = dequant_args[0]
251-
self.edge_program_manager._config_methods[ # pyre-ignore
252-
get_config_method_name(self.param_prefix_name, "output", i, "zp")
289+
self.edge_program_manager._config_methods[
290+
get_config_method_name(self.param_prefix_name, "output", idx, "zp")
253291
] = dequant_args[1]
254292
self.edge_program_manager._config_methods[
255-
get_config_method_name(self.param_prefix_name, "output", i, "quant_min")
293+
get_config_method_name(
294+
self.param_prefix_name, "output", idx, "quant_min"
295+
)
256296
] = dequant_args[2]
257297
self.edge_program_manager._config_methods[
258-
get_config_method_name(self.param_prefix_name, "output", i, "quant_max")
298+
get_config_method_name(
299+
self.param_prefix_name, "output", idx, "quant_max"
300+
)
259301
] = dequant_args[3]
260302
self.edge_program_manager._config_methods[
261-
get_config_method_name(self.param_prefix_name, "output", i, "dtype")
303+
get_config_method_name(self.param_prefix_name, "output", idx, "dtype")
262304
] = scalar_type_enum(dequant_args[4])
263305

306+
def edge_manager_update_quant_config_methods_all(self):
307+
if self.edge_program_manager is not None:
308+
for idx, val in self.dequant_args.items():
309+
self.edge_manager_update_quant_config_method(idx, val)
310+
311+
def call(self, graph_module: torch.fx.GraphModule):
312+
for i in self.quantized_outputs_idx_list:
313+
exported_program = (
314+
self.edge_program_manager.exported_program()
315+
if self.edge_program_manager is not None
316+
else self.exported_program
317+
)
318+
self.dequant_args[i] = quantize_output(exported_program, i) # noqa F841
319+
self.edge_manager_update_quant_config_method(i, self.dequant_args[i])
320+
264321
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)