Skip to content

Commit d60aef5

Browse files
digantdesai and facebook-github-bot
authored and committed
Move quantize IO passes from internal to ExecuTorch (#6686)
Summary: Rationale: code sharing, rather than rewriting. Reviewed By: YIWENX14, kirklandsign. Differential Revision: D65188297
1 parent 17ad8d3 commit d60aef5

File tree

4 files changed

+441
-0
lines changed

4 files changed

+441
-0
lines changed

exir/passes/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ python_library(
1616
":normalize_transpose_pass",
1717
":prim_ops_py_registry",
1818
":quant_fusion_pass",
19+
":quantize_io_pass",
1920
":remove_noop_pass",
2021
":replace_aten_with_edge_pass",
2122
":replace_broken_ops_with_function_ops_pass",
@@ -143,6 +144,19 @@ python_library(
143144
],
144145
)
145146

147+
# Library target for quantize_io_pass.py (quantize_input / quantize_output helpers and
# the QuantizeInputs / QuantizeOutputs passes).
# NOTE(review): quantize_io_pass.py also imports `executorch.exir` (EdgeProgramManager)
# and `executorch.exir.tensor`, which are not listed in deps below; confirm they are
# provided transitively or intentionally omitted.
python_library(
148+
name = "quantize_io_pass",
149+
srcs = [
150+
"quantize_io_pass.py",
151+
],
152+
deps = [
153+
"fbsource//third-party/pypi/numpy:numpy",
154+
"//caffe2:torch",
155+
"//executorch/exir:pass_base",
156+
"//executorch/exir/dialects:lib",
157+
],
158+
)
159+
146160
python_library(
147161
name = "memory_planning_pass",
148162
srcs = [

exir/passes/quantize_io_pass.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
import logging
3+
from typing import Any, Dict, List, Optional, Union
4+
5+
import numpy as np
6+
7+
import torch
8+
9+
from executorch.exir import EdgeProgramManager
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
from executorch.exir.pass_base import ExportPass
13+
from executorch.exir.tensor import scalar_type_enum
14+
from torch.fx.passes.infra.pass_base import PassResult
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
def quantize_input(
    exported_program, input_index, qparams: Optional[Dict[str, Any]] = None
):
    """
    Modify the program to expect quantized input at the given index.

    The model is expected to quantize this input as its first step: the input
    placeholder must have exactly one user, and that user must be a
    ``quantized_decomposed.quantize_per_tensor`` node. Must be called before
    permute_input_layout. The graph is modified in place.

    Args:
        exported_program: Program whose ``graph_module.graph`` is rewritten.
        input_index: Index into ``graph_signature.user_inputs``.
        qparams: Optional dict with keys "scale", "zp" and "dtype" describing
            the quantization of the tensor the caller will pass in. If it is
            omitted, or matches the existing quantize node, the quantize node
            is removed. Otherwise a dequantize node built from ``qparams`` is
            inserted in front of the existing quantize node (requantization).

    Returns:
        The scale, zero point, qmin, qmax, and dtype of the expected input
        quantization.

    Raises:
        ValueError: If the input does not have exactly one user, that user is
            not a quantize op, or the dtypes involved in a requantization are
            not supported.
        AssertionError: If no placeholder matches the input, or ``qparams`` is
            missing one of "scale", "zp", "dtype".
    """
    graph = exported_program.graph_module.graph
    name = exported_program.graph_signature.user_inputs[input_index]
    placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name]
    assert placeholders, f"No placeholder named {name!r} for input {input_index}"
    target_placeholder = placeholders[0]

    # The rewrite below only handles a placeholder feeding a single quantize
    # node; report the real user count (it may also be zero).
    num_users = len(target_placeholder.users)
    if num_users != 1:
        raise ValueError(
            f"Input {input_index} must have exactly one user, but has {num_users}"
        )
    quantize = next(iter(target_placeholder.users))
    if (
        quantize.target
        != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    ):
        raise ValueError(f"Input {input_index} is not used by a quantize op")

    # If user specified qparams are different from args of quantize op, we do
    # requantization instead of eliminating quantize op.
    # quantize.args layout: (input, scale, zero_point, qmin, qmax, dtype).
    need_requant = False
    if qparams is not None:
        assert all(
            qparam in qparams for qparam in ["scale", "zp", "dtype"]
        ), "dtype/scale/zp must be specified in qparam for input requantization"
        if qparams["dtype"] != quantize.args[5]:
            if any(
                dtype
                not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16]
                for dtype in [qparams["dtype"], quantize.args[5]]
            ):
                raise ValueError(
                    f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}"
                )

            need_requant = True
        elif (
            not np.isclose(qparams["scale"], quantize.args[1])
            or qparams["zp"] != quantize.args[2]
        ):
            need_requant = True

    if need_requant:
        assert qparams is not None
        dtype = qparams["dtype"]
        # NOTE(review): torch.bool is accepted by the dtype check above, but
        # torch.iinfo is documented for integer dtypes; confirm bool is intended.
        qmin = torch.iinfo(dtype).min
        qmax = torch.iinfo(dtype).max
        scale = qparams["scale"]
        zero_point = qparams["zp"]
        quant_args = (scale, zero_point, qmin, qmax, dtype)
        logger.info(
            "Modifying program to requantize quantized input at index %s",
            input_index,
        )
        logger.info("Quantization parameters: %s", quant_args)

        # Insert dequantize(user qparams) between the placeholder and the
        # existing quantize node, so the graph requantizes the incoming tensor.
        with graph.inserting_before(quantize):
            input_dequant = graph.call_function(
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                args=(
                    target_placeholder,
                    *quant_args,
                ),
            )
            input_dequant.meta["input_qparams"] = [
                {
                    "scale": scale,
                    "zero_point": zero_point,
                    "qmin": qmin,
                    "qmax": qmax,
                    "dtype": dtype,
                }
            ]
            input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32)
            # The placeholder now carries the caller's quantized dtype.
            target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype)
            quantize.replace_input_with(target_placeholder, input_dequant)
    else:
        quant_args = quantize.args[1:]
        logger.info(
            "Modifying program to take quantized input at index %s", input_index
        )
        logger.info("Quantization parameters: %s", quant_args)

        # Update the placeholder's example value to the quantized form, then
        # route every consumer of the quantize node straight to the placeholder.
        target_placeholder.meta["val"] = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
                target_placeholder.meta["val"], *quant_args
            )
        )
        quantize.replace_all_uses_with(quantize.args[0])

    # Drops the now-unused quantize node (if it was bypassed above).
    graph.eliminate_dead_code()
    return quant_args
113+
114+
115+
def quantize_output(exported_program, output_index):
    """
    Modify the program to produce quantized output at the given index.

    The model is expected to be dequantizing this output as the last step,
    i.e. the selected output must be produced by a
    ``quantized_decomposed.dequantize_per_tensor`` node. Must be called before
    permute_output_layout. The graph is modified in place.

    Args:
        exported_program: Program whose ``graph_module.graph`` is rewritten.
        output_index: Position of the output inside the output node's argument
            list.

    Returns:
        The scale, zero point, qmin, qmax, and dtype of the output
        quantization (the dequantize node's arguments after its input).

    Raises:
        NotImplementedError: If the graph does not have exactly one output node.
        ValueError: If ``output_index`` is out of bounds or the selected output
            is not a dequantize op.
    """
    graph = exported_program.graph_module.graph
    outputs = [n for n in graph.nodes if n.op == "output"]
    if len(outputs) != 1:
        raise NotImplementedError("Only 1 output node is supported")

    output_node = outputs[0]
    output_list = list(output_node.args[0])
    if output_index >= len(output_list):
        raise ValueError(
            f"{len(output_list)} outputs available, "
            f"output index out of bounds: {output_index}"
        )

    dequant = output_list[output_index]
    # getattr: an output entry that is not a graph node (no ``target``) is
    # reported through the same ValueError instead of an AttributeError.
    if (
        getattr(dequant, "target", None)
        != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    ):
        raise ValueError(f"Output {output_index} is not a dequantize op")

    # Return the dequantize node's input directly; the dequantize node is then
    # removed by dead-code elimination if nothing else uses it.
    output_list[output_index] = dequant.args[0]
    output_node.args = (output_list,)
    dequant_args = dequant.args[1:]
    graph.eliminate_dead_code()

    logger.info(
        "Modifying program to produce quantized output at index %s", output_index
    )
    logger.info("Dequantization parameters: %s", dequant_args)
    return dequant_args
153+
154+
155+
def get_config_method_name(
    prefix: Optional[str] = "forward",
    arg_type: str = "input",
    index: int = 0,
    key: str = "scale",
):
    """
    Build the name of the config method that stores one quantization parameter.

    The result has the form ``<prefix>_<arg_type><index>_<key>``; when
    ``prefix`` is None the leading ``<prefix>_`` part is omitted.

    Args:
        prefix: Optional name prepended (followed by "_") to the result.
        arg_type: Either "input" or "output".
        index: Non-negative input/output index.
        key: One of "scale", "zp", "quant_min", "quant_max", "dtype".

    Raises:
        AssertionError: If ``arg_type``, ``index`` or ``key`` is invalid.
    """
    qualifier = "" if prefix is None else prefix + "_"

    assert arg_type in ("input", "output"), "arg_type must be either input or output"
    assert index >= 0, "index must be non-negative"
    assert key in (
        "scale",
        "zp",
        "quant_min",
        "quant_max",
        "dtype",
    ), "key must be one of scale, zp, quant_min, quant_max, dtype"

    return f"{qualifier}{arg_type}{index}_{key}"
175+
176+
177+
class QuantizeInputs(ExportPass):
    """
    Pass that makes selected program inputs quantized.

    For every configured input index it calls ``quantize_input`` on the edge
    program manager's exported program and records the resulting scale, zero
    point, quant_min, quant_max and dtype in the manager's ``_config_methods``
    under names produced by ``get_config_method_name``.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose exported program is rewritten
                and whose ``_config_methods`` receive the quantization params.
            quantized_inputs_idx: Either a mapping from input index to the
                qparams dict handed to ``quantize_input``, or a plain list of
                input indices (each mapped to ``None`` qparams).
            method_name: Prefix used for the generated config method names.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager

        if isinstance(quantized_inputs_idx, dict):
            self.quantized_inputs_idx_dict = quantized_inputs_idx
        else:
            # A bare list of indices means "no user-specified qparams".
            self.quantized_inputs_idx_dict = dict.fromkeys(quantized_inputs_idx)
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Quantize each configured input and publish its quantization params.

        The rewrite is applied to ``edge_program_manager.exported_program()``;
        ``graph_module`` is returned unchanged inside the ``PassResult``.
        """
        manager = self.edge_program_manager
        for input_idx, qparams in self.quantized_inputs_idx_dict.items():
            quant_args = quantize_input(manager.exported_program(), input_idx, qparams)

            if not manager._config_methods:
                manager._config_methods = {}

            # quant_args is ordered (scale, zp, quant_min, quant_max, dtype).
            for pos, key in enumerate(("scale", "zp", "quant_min", "quant_max")):
                manager._config_methods[  # pyre-ignore
                    get_config_method_name(
                        self.param_prefix_name, "input", input_idx, key
                    )
                ] = quant_args[pos]
            # The dtype is stored as its scalar type enum value.
            manager._config_methods[  # pyre-ignore
                get_config_method_name(
                    self.param_prefix_name, "input", input_idx, "dtype"
                )
            ] = scalar_type_enum(quant_args[4])
        return PassResult(graph_module, True)
220+
221+
222+
class QuantizeOutputs(ExportPass):
    """
    Pass that makes selected program outputs quantized.

    For every configured output index it calls ``quantize_output`` on the edge
    program manager's exported program and records the resulting scale, zero
    point, quant_min, quant_max and dtype in the manager's ``_config_methods``
    under names produced by ``get_config_method_name``.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_outputs_idx_list: List[int],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose exported program is rewritten
                and whose ``_config_methods`` receive the quantization params.
            quantized_outputs_idx_list: Indices of the outputs to quantize.
            method_name: Prefix used for the generated config method names.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager
        self.quantized_outputs_idx_list = quantized_outputs_idx_list
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Quantize each configured output and publish its quantization params.

        The rewrite is applied to ``edge_program_manager.exported_program()``;
        ``graph_module`` is returned unchanged inside the ``PassResult``.
        """
        manager = self.edge_program_manager
        for output_idx in self.quantized_outputs_idx_list:
            dequant_args = quantize_output(manager.exported_program(), output_idx)

            if not manager._config_methods:
                manager._config_methods = {}

            # dequant_args is ordered (scale, zp, quant_min, quant_max, dtype).
            for pos, key in enumerate(("scale", "zp", "quant_min", "quant_max")):
                manager._config_methods[  # pyre-ignore
                    get_config_method_name(
                        self.param_prefix_name, "output", output_idx, key
                    )
                ] = dequant_args[pos]
            # The dtype is stored as its scalar type enum value.
            manager._config_methods[  # pyre-ignore
                get_config_method_name(
                    self.param_prefix_name, "output", output_idx, "dtype"
                )
            ] = scalar_type_enum(dequant_args[4])

        return PassResult(graph_module, True)

exir/tests/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,15 @@ python_unittest(
448448
"//executorch/exir:_warnings",
449449
],
450450
)
451+
452+
# Unit-test target for the quantize IO passes; runs test_quantize_io_pass.py against
# //executorch/exir/passes:quantize_io_pass.
python_unittest(
453+
name = "quantize_io_pass",
454+
srcs = [
455+
"test_quantize_io_pass.py",
456+
],
457+
deps = [
458+
"//caffe2:torch",
459+
"//executorch/exir:lib",
460+
"//executorch/exir/passes:quantize_io_pass",
461+
],
462+
)

0 commit comments

Comments
 (0)