|
| 1 | +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. |
| 2 | +import logging |
| 3 | +from typing import Any, Dict, List, Optional, Union |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +import torch |
| 8 | + |
| 9 | +from executorch.exir import EdgeProgramManager |
| 10 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 11 | + |
| 12 | +from executorch.exir.pass_base import ExportPass |
| 13 | +from executorch.exir.tensor import scalar_type_enum |
| 14 | +from torch.fx.passes.infra.pass_base import PassResult |
| 15 | + |
| 16 | +logger = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
| 19 | +def quantize_input( |
| 20 | + exported_program, input_index, qparams: Optional[Dict[str, Any]] = None |
| 21 | +): |
| 22 | + """ |
| 23 | + Modify the program to expect quantized input at given index. The input is expected |
| 24 | + to be quantizing this input as the first step. Must be called before |
| 25 | + permute_input_layout. Returns the scale, zero point, qmin, qmax, and dtype of the |
| 26 | + expected quantization. |
| 27 | + """ |
| 28 | + graph = exported_program.graph_module.graph |
| 29 | + name = exported_program.graph_signature.user_inputs[input_index] |
| 30 | + placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name] |
| 31 | + assert placeholders |
| 32 | + target_placeholder = placeholders[0] |
| 33 | + |
| 34 | + if len(target_placeholder.users) != 1: |
| 35 | + raise ValueError(f"Input {input_index} has more than one users") |
| 36 | + quantize = next(iter(target_placeholder.users)) |
| 37 | + if ( |
| 38 | + quantize.target |
| 39 | + != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default |
| 40 | + ): |
| 41 | + raise ValueError(f"Input {input_index} is not used by a quantize op") |
| 42 | + |
| 43 | + # If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op |
| 44 | + need_requant = False |
| 45 | + if qparams is not None: |
| 46 | + assert all( |
| 47 | + qparam in qparams for qparam in ["scale", "zp", "dtype"] |
| 48 | + ), "dtype/scale/zp must be specified in qparam for input requantization" |
| 49 | + if qparams["dtype"] != quantize.args[5]: |
| 50 | + if any( |
| 51 | + dtype |
| 52 | + not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16] |
| 53 | + for dtype in [qparams["dtype"], quantize.args[5]] |
| 54 | + ): |
| 55 | + raise ValueError( |
| 56 | + f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}" |
| 57 | + ) |
| 58 | + |
| 59 | + need_requant = True |
| 60 | + elif ( |
| 61 | + not np.isclose(qparams["scale"], quantize.args[1]) |
| 62 | + or qparams["zp"] != quantize.args[2] |
| 63 | + ): |
| 64 | + need_requant = True |
| 65 | + |
| 66 | + if need_requant: |
| 67 | + assert qparams is not None |
| 68 | + dtype = qparams["dtype"] |
| 69 | + qmin = torch.iinfo(dtype).min |
| 70 | + qmax = torch.iinfo(dtype).max |
| 71 | + scale = qparams["scale"] |
| 72 | + zero_point = qparams["zp"] |
| 73 | + quant_args = (scale, zero_point, qmin, qmax, dtype) |
| 74 | + logger.info( |
| 75 | + f"Modifying program to requantize quantized input at index {input_index}" |
| 76 | + ) |
| 77 | + logger.info(f"Quantization parameters: {quant_args}") |
| 78 | + |
| 79 | + with exported_program.graph_module.graph.inserting_before(quantize): |
| 80 | + input_dequant = exported_program.graph_module.graph.call_function( |
| 81 | + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 82 | + args=( |
| 83 | + target_placeholder, |
| 84 | + *quant_args, |
| 85 | + ), |
| 86 | + ) |
| 87 | + input_dequant.meta["input_qparams"] = [ |
| 88 | + { |
| 89 | + "scale": scale, |
| 90 | + "zero_point": zero_point, |
| 91 | + "qmin": qmin, |
| 92 | + "qmax": qmax, |
| 93 | + "dtype": dtype, |
| 94 | + } |
| 95 | + ] |
| 96 | + input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32) |
| 97 | + target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype) |
| 98 | + quantize.replace_input_with(target_placeholder, input_dequant) |
| 99 | + else: |
| 100 | + quant_args = quantize.args[1:] |
| 101 | + logger.info(f"Modifying program to take quantized input at index {input_index}") |
| 102 | + logger.info(f"Quantization parameters: {quant_args}") |
| 103 | + |
| 104 | + target_placeholder.meta["val"] = ( |
| 105 | + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( |
| 106 | + target_placeholder.meta["val"], *quant_args |
| 107 | + ) |
| 108 | + ) |
| 109 | + quantize.replace_all_uses_with(quantize.args[0]) |
| 110 | + |
| 111 | + exported_program.graph_module.graph.eliminate_dead_code() |
| 112 | + return quant_args |
| 113 | + |
| 114 | + |
| 115 | +def quantize_output(exported_program, output_index): |
| 116 | + """ |
| 117 | + Modify the program to produce quantized output at given index. The model is expected |
| 118 | + to be dequantizing this output as the last step. Must be called before |
| 119 | + permute_output_layout. Returns the scale, zero point, qmin, qmax, and dtype of the |
| 120 | + output quantization. |
| 121 | + """ |
| 122 | + graph = exported_program.graph_module.graph |
| 123 | + outputs = [n for n in graph.nodes if n.op == "output"] |
| 124 | + if len(outputs) != 1: |
| 125 | + raise NotImplementedError("Only 1 output node is supported") |
| 126 | + |
| 127 | + output_node = outputs[0] |
| 128 | + output_list = list(output_node.args[0]) |
| 129 | + if output_index >= len(output_list): |
| 130 | + raise ValueError( |
| 131 | + f"{len(output_list)} outputs available, " |
| 132 | + + f"output index out of bounds: {output_index}" |
| 133 | + ) |
| 134 | + |
| 135 | + target_output = output_list[output_index] |
| 136 | + if ( |
| 137 | + target_output.target |
| 138 | + != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default |
| 139 | + ): |
| 140 | + raise ValueError("Output {output_index} is not a dequantize op") |
| 141 | + |
| 142 | + dequant = target_output |
| 143 | + output_list[output_index] = dequant.args[0] |
| 144 | + output_node.args = (output_list,) |
| 145 | + dequant_args = dequant.args[1:] |
| 146 | + graph.eliminate_dead_code() |
| 147 | + |
| 148 | + logger.info( |
| 149 | + f"Modifying program to produce quantized output at index {output_index}" |
| 150 | + ) |
| 151 | + logger.info(f"Dequantization parameters: {dequant_args}") |
| 152 | + return dequant_args |
| 153 | + |
| 154 | + |
| 155 | +def get_config_method_name( |
| 156 | + prefix: Optional[str] = "forward", |
| 157 | + arg_type: str = "input", |
| 158 | + index: int = 0, |
| 159 | + key: str = "scale", |
| 160 | +): |
| 161 | + if prefix is None: |
| 162 | + prefix = "" |
| 163 | + else: |
| 164 | + prefix = prefix + "_" |
| 165 | + assert arg_type in ["input", "output"], "arg_type must be either input or output" |
| 166 | + assert index >= 0, "index must be non-negative" |
| 167 | + assert key in [ |
| 168 | + "scale", |
| 169 | + "zp", |
| 170 | + "quant_min", |
| 171 | + "quant_max", |
| 172 | + "dtype", |
| 173 | + ], "key must be one of scale, zp, quant_min, quant_max, dtype" |
| 174 | + return f"{prefix}{arg_type}{index}_{key}" |
| 175 | + |
| 176 | + |
| 177 | +class QuantizeInputs(ExportPass): |
| 178 | + def __init__( |
| 179 | + self, |
| 180 | + edge_program_manager: EdgeProgramManager, |
| 181 | + quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]], |
| 182 | + method_name: Optional[str] = None, |
| 183 | + ): |
| 184 | + super().__init__() |
| 185 | + self.edge_program_manager = edge_program_manager |
| 186 | + |
| 187 | + self.quantized_inputs_idx_dict = {} |
| 188 | + if isinstance(quantized_inputs_idx, dict): |
| 189 | + self.quantized_inputs_idx_dict = quantized_inputs_idx |
| 190 | + else: |
| 191 | + for idx in quantized_inputs_idx: |
| 192 | + self.quantized_inputs_idx_dict[idx] = None |
| 193 | + self.param_prefix_name = method_name |
| 194 | + |
| 195 | + def call(self, graph_module: torch.fx.GraphModule): |
| 196 | + for i, qparams in self.quantized_inputs_idx_dict.items(): |
| 197 | + quant_args = quantize_input( |
| 198 | + self.edge_program_manager.exported_program(), i, qparams |
| 199 | + ) |
| 200 | + |
| 201 | + if not self.edge_program_manager._config_methods: |
| 202 | + self.edge_program_manager._config_methods = {} |
| 203 | + |
| 204 | + self.edge_program_manager._config_methods[ |
| 205 | + get_config_method_name(self.param_prefix_name, "input", i, "scale") |
| 206 | + ] = quant_args[0] |
| 207 | + self.edge_program_manager._config_methods[ |
| 208 | + get_config_method_name(self.param_prefix_name, "input", i, "zp") |
| 209 | + ] = quant_args[1] |
| 210 | + self.edge_program_manager._config_methods[ |
| 211 | + get_config_method_name(self.param_prefix_name, "input", i, "quant_min") |
| 212 | + ] = quant_args[2] |
| 213 | + self.edge_program_manager._config_methods[ |
| 214 | + get_config_method_name(self.param_prefix_name, "input", i, "quant_max") |
| 215 | + ] = quant_args[3] |
| 216 | + self.edge_program_manager._config_methods[ |
| 217 | + get_config_method_name(self.param_prefix_name, "input", i, "dtype") |
| 218 | + ] = scalar_type_enum(quant_args[4]) |
| 219 | + return PassResult(graph_module, True) |
| 220 | + |
| 221 | + |
| 222 | +class QuantizeOutputs(ExportPass): |
| 223 | + def __init__( |
| 224 | + self, |
| 225 | + edge_program_manager: EdgeProgramManager, |
| 226 | + quantized_outputs_idx_list: List[int], |
| 227 | + method_name: Optional[str] = None, |
| 228 | + ): |
| 229 | + super().__init__() |
| 230 | + self.edge_program_manager = edge_program_manager |
| 231 | + self.quantized_outputs_idx_list = quantized_outputs_idx_list |
| 232 | + self.param_prefix_name = method_name |
| 233 | + |
| 234 | + def call(self, graph_module: torch.fx.GraphModule): |
| 235 | + for i in self.quantized_outputs_idx_list: |
| 236 | + dequant_args = quantize_output( |
| 237 | + self.edge_program_manager.exported_program(), i |
| 238 | + ) # noqa F841 |
| 239 | + |
| 240 | + if not self.edge_program_manager._config_methods: |
| 241 | + self.edge_program_manager._config_methods = {} |
| 242 | + |
| 243 | + self.edge_program_manager._config_methods[ |
| 244 | + get_config_method_name(self.param_prefix_name, "output", i, "scale") |
| 245 | + ] = dequant_args[0] |
| 246 | + self.edge_program_manager._config_methods[ |
| 247 | + get_config_method_name(self.param_prefix_name, "output", i, "zp") |
| 248 | + ] = dequant_args[1] |
| 249 | + self.edge_program_manager._config_methods[ |
| 250 | + get_config_method_name(self.param_prefix_name, "output", i, "quant_min") |
| 251 | + ] = dequant_args[2] |
| 252 | + self.edge_program_manager._config_methods[ |
| 253 | + get_config_method_name(self.param_prefix_name, "output", i, "quant_max") |
| 254 | + ] = dequant_args[3] |
| 255 | + self.edge_program_manager._config_methods[ |
| 256 | + get_config_method_name(self.param_prefix_name, "output", i, "dtype") |
| 257 | + ] = scalar_type_enum(dequant_args[4]) |
| 258 | + |
| 259 | + return PassResult(graph_module, True) |
0 commit comments