 # LICENSE file in the root directory of this source tree.

 from enum import Enum
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from executorch.backends.arm.arm_backend import (
…
 from executorch.backends.arm.arm_partitioner import ArmPartitioner

 from executorch.backends.arm.test.tosautil.tosa_test_utils import (
+    QuantizationParams,
     TosaProfile,
     TosaTestUtils,
 )
…
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.export import ExportedProgram


 class ArmBackendSelector(Enum):
@@ -61,6 +63,7 @@ def __init__(
         TosaProfile.BI or TosaProfile.MI
         """
         self.tosa_test_util = None
+        self.is_quantized = profile == TosaProfile.BI
         if backend == ArmBackendSelector.TOSA:
             self.tosa_test_util = TosaTestUtils(profile=profile)
             # The spec below triggers arm_backend.py to output two files:
@@ -119,54 +122,121 @@ def run_method(
         ), "self.tosa_test_util is not initialized, cannot use run_method()"
         inputs_to_run = inputs or self.inputs

-        # TODO: we can't possible need to use all these stages??
-        export_stage = self.stages[
-            self.stage_name(Export)
-        ]  # this is what XNNpack use to get quant params
-        toedge_stage = self.stages[
-            self.stage_name(ToEdge)
-        ]  # this is what get_input_quantization_params use to get quant params
-        partition_stage = self.stages[
-            self.stage_name(Partition)
-        ]  # this is what tosa_ref_dump_inputs use....
-
-        # TODO: I'd prefer to use this TOSA buffer instead of output.tosa,
-        # generated by arm_backend.py. The issue is that we're still depending
-        # on desc.json, which is created from TosaSerializer class, not from
-        # the serialized TOSA buffer. Leave this here for review purposes.
-        # ts_serialized = self._get_serialized_tosa_buffer(  # unused
-        #     partition_stage.artifact
-        # )
-
-        # This is where the torch reference output is calculated and set
-        # TODO: This sets self.quantization_scale, which is duplicates
-        # self.tosa_test_util.quantization.output.scales (?). Fixme.
-        (
-            self.reference_output,
-            self.quantization_scale,
-        ) = self._calculate_reference_output(export_stage.artifact, inputs_to_run)
-
-        # Convert the torch inputs to something TOSA ref model can use
-        tensor_names_and_inputs_np = self.tosa_test_util.convert_inputs_to_tosa(
-            partition_stage.artifact, toedge_stage.artifact, inputs_to_run
+        export_stage = self.stages[self.stage_name(Export)]
+
+        (input_names, qp_input) = self._get_input_params(export_stage.artifact)
+        (output_name, qp_output) = self._get_output_param(export_stage.artifact)
+
+        # Calculate the reference output using the original module or the quant
+        # module. self.quantization_scale is used by compare_outputs() to
+        # calculate the tolerance
+        self.quantization_scale = None if qp_output is None else qp_output.scale
+        if self.is_quantized:
+            module_for_ref = self.stages[self.stage_name(Quantize)].artifact
+        else:
+            module_for_ref = self.original_module
+        self.reference_output = self._calculate_reference_output(
+            module_for_ref, inputs_to_run
         )

         # Run the TOSA ref model to get the output tensor, which will be
         # compared to the torch output in compare_outputs()
         self.stage_output = self.tosa_test_util.run_tosa_ref_model(
-            tensor_names_and_inputs_np
+            params_input=(input_names, qp_input),
+            param_output=(output_name, qp_output),
+            inputs=inputs_to_run,
         )

         return self

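The run_method() flow above is the tail end of a stage pipeline. For orientation, here is a minimal sketch of how such a tester is typically driven; the ArmTester name and the exact stage-method chain are assumed from the surrounding test harness, not part of this diff:

```python
# Hypothetical driver, assuming a Tester-style stage API
# (quantize/export/to_edge/partition) inherited from the XNNPACK
# test harness this class builds on.
tester = ArmTester(
    module=MyModule(),
    inputs=(torch.ones(1, 3, 32, 32),),
    profile=TosaProfile.BI,  # BI selects the quantized path
    backend=ArmBackendSelector.TOSA,
)
(
    tester.quantize()  # skipped on the float (MI) profile
    .export()
    .to_edge()
    .partition()
    .run_method()
    .compare_outputs()
)
```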
-    def _get_serialized_tosa_buffer(self, partition_stage: Partition) -> bytes:
+    def _get_input_params(
+        self, program: ExportedProgram
+    ) -> Tuple[List[str], Union[List[QuantizationParams], List[None]]]:
         """
-        This is just a prototype...
-        Todo:
-            * The "_0" indicates that there are many lowered modules. Loop it!
-            * There's probably a better way to get this buffer. An API? Yes,
-              it seems the serialize stage does this for you...
+        Get the names of, and (for quantized models) the quantization
+        parameters for, the user inputs to this model.
+
+        Args:
+            program (ExportedProgram): The program to get input parameters from
+        Returns:
+            Tuple[List[str], Union[List[QuantizationParams], List[None]]]: The
+            input node names and their quantization parameters (a list of
+            Nones for non-quantized models).
+        """
+        input_names = []
+        # E.g. bias and weights are 'placeholders' as well. This is used to
+        # get only the user inputs.
+        usr_inputs = program.graph_signature.user_inputs
+        for node in program.graph.nodes:
+            if node.op == "placeholder" and node.name in usr_inputs:
+                input_names.append(node.name)
+
+        if self.is_quantized:
+            quant_params = []
+            for node in program.graph.nodes:
+                if (
+                    node.target
+                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
+                    and node.args[0].name in input_names
+                ):
+                    qp = QuantizationParams(
+                        node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                    )
+                    quant_params.append(qp)
+                    if len(quant_params) == len(
+                        input_names
+                    ):  # break early once all inputs have quantization params
+                        break
+            assert len(quant_params) != 0, "Quantization parameters not found"
+            return (input_names, quant_params)
+        else:
+            return (input_names, len(input_names) * [None])  # a list of Nones
+
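Why args[1] and args[2] are read as scale and zero point in _get_input_params(): the quantized_decomposed op takes them positionally after the tensor. Schematically, one user input in program.graph is consumed like this (illustrative values, not taken from this change):

```python
# Schematic node as produced by the quantizer for one user input:
torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x,            # args[0]: the input placeholder node
    0.02,         # args[1]: scale
    128,          # args[2]: zero point
    0,            # args[3]: quant_min
    255,          # args[4]: quant_max
    torch.uint8,  # args[5]: dtype
)
```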
+    def _get_output_param(
+        self, program: ExportedProgram
+    ) -> Tuple[str, Union[QuantizationParams, None]]:
         """
-        return partition_stage._edge_programs[
-            "forward"
-        ]._graph_module.lowered_module_0.processed_bytes
+        Get the name of, and (for quantized models) the quantization
+        parameters for, the output of this model.
+
+        Args:
+            program (ExportedProgram): The program to get output parameters from.
+        Returns:
+            Tuple[str, Optional[QuantizationParams]]: A tuple containing the
+            output node name and its quantization parameters.
+        """
+        output_node = None
+        for node in program.graph.nodes:
+            if node.op == "output":
+                output_node = node
+                break
+
+        if self.is_quantized:
+            quant_params = None
+            for node in program.graph.nodes:
+                if (
+                    node.target
+                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                    and node == output_node.args[0][0]
+                ):
+                    quant_params = QuantizationParams(
+                        node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                    )
+                    break  # break early, there's only one output node
+            assert quant_params is not None, "Quantization parameters not found"
+            return (output_node.name, quant_params)
+        else:
+            return (output_node.name, None)
+
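A sketch of how QuantizationParams would typically be applied on either side of the reference-model run; the helpers below are illustrative stand-ins for the conversion that actually happens inside TosaTestUtils.run_tosa_ref_model():

```python
import numpy as np
import torch

def quantize_input(x: torch.Tensor, qp: "QuantizationParams") -> np.ndarray:
    # Affine quantization: q = round(x / scale) + zero_point, clamped to
    # int8 (assuming the TOSA BI profile runs on int8 tensors).
    q = torch.round(x / qp.scale) + qp.zp
    return torch.clamp(q, -128, 127).to(torch.int8).numpy()

def dequantize_output(q: np.ndarray, qp: "QuantizationParams") -> torch.Tensor:
    # Inverse mapping back to float, so the ref-model result can be
    # compared against the torch reference output.
    return (torch.from_numpy(q).to(torch.float32) - qp.zp) * qp.scale
```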
+    @staticmethod
+    def _calculate_reference_output(
+        module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
+    ) -> torch.Tensor:
+        """
+        Note: I'd prefer to use the base class method here, but since it uses
+        the exported program, I can't. The partitioner stage clears the
+        state_dict of the exported program, which causes an issue when
+        evaluating the module.
+        """
+
+        return module.forward(*inputs)
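For completeness, the tolerance plumbing: the self.quantization_scale set in run_method() is what compare_outputs() can use to loosen the comparison for quantized runs. A sketch under that assumption (the real compare_outputs() belongs to the base tester and may differ):

```python
def compare_outputs(self, float_atol: float = 1e-6):
    # On the quantized (BI) path, up to ~one quantization step of error is
    # expected, so tie the tolerance to the output scale; float (MI) runs
    # are compared near-exactly.
    atol = self.quantization_scale or float_atol
    assert torch.allclose(
        self.reference_output, self.stage_output, atol=atol
    ), f"Reference and TOSA ref-model outputs differ (atol={atol})"
    return self
```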