
Commit b4836ed

digantdesai authored and facebook-github-bot committed
Move quantize IO passes from BoltNN to ExecuTorch
Summary: Rationale: code sharing, rather than rewriting.

Differential Revision: D65188297
1 parent 179d495 commit b4836ed

File tree

2 files changed

+231 −0 lines changed

exir/passes/TARGETS

Lines changed: 14 additions & 0 deletions
@@ -16,6 +16,7 @@ python_library(
         ":normalize_transpose_pass",
         ":prim_ops_py_registry",
         ":quant_fusion_pass",
+        ":quantize_io_pass",
         ":remove_noop_pass",
         ":replace_aten_with_edge_pass",
         ":replace_broken_ops_with_function_ops_pass",
@@ -143,6 +144,19 @@ python_library(
     ],
 )
 
+python_library(
+    name = "quantize_io_pass",
+    srcs = [
+        "quantize_io_pass.py",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/numpy:numpy",
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 python_library(
     name = "memory_planning_pass",
     srcs = [
exir/passes/quantize_io_pass.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np

import torch

from executorch.exir import EdgeProgramManager
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass
from executorch.exir.tensor import scalar_type_enum
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)


def quantize_input(
    exported_program, input_index, qparams: Optional[Dict[str, Any]] = None
):
    """
    Modify the program to expect quantized input at the given index. The model is
    expected to be quantizing this input as the first step. Must be called before
    permute_input_layout. Returns the scale, zero point, qmin, qmax, and dtype of
    the expected quantization.
    """
    graph = exported_program.graph_module.graph
    name = exported_program.graph_signature.user_inputs[input_index]
    placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name]
    assert placeholders
    target_placeholder = placeholders[0]

    if len(target_placeholder.users) != 1:
        raise ValueError(f"Input {input_index} has more than one user")
    quantize = next(iter(target_placeholder.users))
    if (
        quantize.target
        != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    ):
        raise ValueError(f"Input {input_index} is not used by a quantize op")

    # If user-specified qparams differ from the args of the quantize op, requantize
    # instead of eliminating the quantize op.
    need_requant = False
    if qparams is not None:
        assert all(
            qparam in qparams for qparam in ["scale", "zp", "dtype"]
        ), "dtype/scale/zp must be specified in qparams for input requantization"
        if qparams["dtype"] != quantize.args[5]:
            if any(
                dtype
                not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16]
                for dtype in [qparams["dtype"], quantize.args[5]]
            ):
                raise ValueError(
                    f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}"
                )

            need_requant = True
        elif (
            not np.isclose(qparams["scale"], quantize.args[1])
            or qparams["zp"] != quantize.args[2]
        ):
            need_requant = True

    if need_requant:
        assert qparams is not None
        dtype = qparams["dtype"]
        qmin = torch.iinfo(dtype).min
        qmax = torch.iinfo(dtype).max
        scale = qparams["scale"]
        zero_point = qparams["zp"]
        quant_args = (scale, zero_point, qmin, qmax, dtype)
        logger.info(
            f"Modifying program to requantize quantized input at index {input_index}"
        )
        logger.info(f"Quantization parameters: {quant_args}")

        with exported_program.graph_module.graph.inserting_before(quantize):
            input_dequant = exported_program.graph_module.graph.call_function(
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                args=(
                    target_placeholder,
                    *quant_args,
                ),
            )
            input_dequant.meta["input_qparams"] = [
                {
                    "scale": scale,
                    "zero_point": zero_point,
                    "qmin": qmin,
                    "qmax": qmax,
                    "dtype": dtype,
                }
            ]
            input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32)
            target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype)
            quantize.replace_input_with(target_placeholder, input_dequant)
    else:
        quant_args = quantize.args[1:]
        logger.info(f"Modifying program to take quantized input at index {input_index}")
        logger.info(f"Quantization parameters: {quant_args}")

        target_placeholder.meta["val"] = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
                target_placeholder.meta["val"], *quant_args
            )
        )
        quantize.replace_all_uses_with(quantize.args[0])

    exported_program.graph_module.graph.eliminate_dead_code()
    return quant_args


def quantize_output(exported_program, output_index):
    """
    Modify the program to produce quantized output at the given index. The model is
    expected to be dequantizing this output as the last step. Must be called before
    permute_output_layout. Returns the scale, zero point, qmin, qmax, and dtype of
    the output quantization.
    """
    graph = exported_program.graph_module.graph
    outputs = [n for n in graph.nodes if n.op == "output"]
    if len(outputs) != 1:
        raise NotImplementedError("Only 1 output node is supported")

    output_node = outputs[0]
    output_list = list(output_node.args[0])
    if output_index >= len(output_list):
        raise ValueError(
            f"{len(output_list)} outputs available, "
            + f"output index out of bounds: {output_index}"
        )

    target_output = output_list[output_index]
    if (
        target_output.target
        != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    ):
        raise ValueError(f"Output {output_index} is not a dequantize op")

    dequant = target_output
    output_list[output_index] = dequant.args[0]
    output_node.args = (output_list,)
    dequant_args = dequant.args[1:]
    graph.eliminate_dead_code()

    logger.info(
        f"Modifying program to produce quantized output at index {output_index}"
    )
    logger.info(f"Quantization parameters: {dequant_args}")
    return dequant_args


class QuantizeInputs(ExportPass):
    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],
        method_name: Optional[str] = None,
    ):
        super().__init__()
        self.edge_program_manager = edge_program_manager

        self.quantized_inputs_idx_dict = {}
        if isinstance(quantized_inputs_idx, dict):
            self.quantized_inputs_idx_dict = quantized_inputs_idx
        else:
            for idx in quantized_inputs_idx:
                self.quantized_inputs_idx_dict[idx] = None
        self.param_prefix_name = "" if method_name is None else method_name + "_"

    def call(self, graph_module: torch.fx.GraphModule):
        for i, qparams in self.quantized_inputs_idx_dict.items():
            quant_args = quantize_input(
                self.edge_program_manager.exported_program(), i, qparams
            )

            if not self.edge_program_manager._config_methods:
                self.edge_program_manager._config_methods = {}

            self.edge_program_manager._config_methods[
                self.param_prefix_name + f"input{i}_scale"
            ] = quant_args[0]
            self.edge_program_manager._config_methods[
                self.param_prefix_name + f"input{i}_zp"
            ] = quant_args[1]  # pyre-ignore
            self.edge_program_manager._config_methods[
                self.param_prefix_name + f"input{i}_quant_min"
            ] = quant_args[2]
            self.edge_program_manager._config_methods[
                self.param_prefix_name + f"input{i}_quant_max"
            ] = quant_args[3]
            self.edge_program_manager._config_methods[
                self.param_prefix_name + f"input{i}_dtype"
            ] = scalar_type_enum(quant_args[4])
        return PassResult(graph_module, True)


class QuantizeOutputs(ExportPass):
    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_outputs_idx_list: List[int],
        method_name: Optional[str] = None,
    ):
        super().__init__()
        self.edge_program_manager = edge_program_manager
        self.quantized_outputs_idx_list = quantized_outputs_idx_list
        self.param_prefix_name = "" if method_name is None else method_name + "_"

    def call(self, graph_module: torch.fx.GraphModule):
        for i in self.quantized_outputs_idx_list:
            dequant_args = quantize_output(
                self.edge_program_manager.exported_program(), i
            )  # noqa: F841

            if not self.edge_program_manager._config_methods:
                self.edge_program_manager._config_methods = {}

            self.edge_program_manager._config_methods[
                self.param_prefix_name + f"output{i}_scale"
            ] = dequant_args[0]
            self.edge_program_manager._config_methods[
                self.param_prefix_name + f"output{i}_zp"
            ] = dequant_args[1]  # pyre-ignore

        return PassResult(graph_module, True)
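For context, a usage sketch of the moved passes, applied through EdgeProgramManager.transform. The variable edge_manager, the example indices, and the qparams values are illustrative assumptions, not part of this commit:

import torch

from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs

# edge_manager: an EdgeProgramManager produced earlier by to_edge().
# Input 0 is expected (or requantized) to the given qparams; output 0 is
# left in quantized form instead of being dequantized inside the graph.
edge_manager = edge_manager.transform(
    [
        QuantizeInputs(
            edge_manager,
            quantized_inputs_idx={0: {"scale": 0.1, "zp": 0, "dtype": torch.uint8}},
        ),
        QuantizeOutputs(edge_manager, quantized_outputs_idx_list=[0]),
    ]
)
# The recorded parameters land in edge_manager._config_methods under keys
# such as "input0_scale", "input0_zp", "output0_scale", and "output0_zp".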
